mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-26 05:32:57 +08:00
153 lines
6.8 KiB
Python
153 lines
6.8 KiB
Python
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import os
|
|
import sys
|
|
import unittest
|
|
|
|
import torch
|
|
from parameterized import parameterized
|
|
|
|
import tensorrt_llm
|
|
import tensorrt_llm.models.redrafter
|
|
import tensorrt_llm.models.redrafter.redrafter_helper
|
|
from tensorrt_llm import Tensor
|
|
|
|
sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
|
|
from utils.util import create_session, run_session, set_input_shape
|
|
|
|
|
|
class TestReDrafter(unittest.TestCase):
    """Unit tests for ReDrafter helper ops, built and run as TensorRT networks."""

    def setUp(self):
        # Keep TensorRT-LLM logging quiet while the tests build engines.
        tensorrt_llm.logger.set_level('warning')

    ########################################################################################################################

    @parameterized.expand([
        ([3], [[[0.3990, 0.5167, 0.0249, 0.9401],
                [0.9459, 0.7967, 0.4150, 0.8203],
                [0.2290, 0.9096, 0.1183, 0.0752]]]),
        ([3, 4], [[[0.3990, 0.5167, 0.0249, 0.9401],
                   [0.9459, 0.7967, 0.4150, 0.8203],
                   [0.2290, 0.9096, 0.1183, 0.0752],
                   [0.4092, 0.9601, 0.2093, 0.1940]],
                  [[0.4092, 0.9601, 0.2093, 0.1940],
                   [0.8909, 0.4387, 0.3570, 0.5454],
                   [0.8299, 0.2099, 0.7684, 0.4290],
                   [0.2117, 0.6606, 0.1654, 0.4250]]]),
        ([4, 3], [[[0.3990, 0.5167, 0.0249, 0.9401],
                   [0.9459, 0.7967, 0.4150, 0.8203],
                   [0.2290, 0.9096, 0.1183, 0.0752],
                   [0.4092, 0.9601, 0.2093, 0.1940]],
                  [[0.8909, 0.4387, 0.3570, 0.5454],
                   [0.8299, 0.2099, 0.7684, 0.4290],
                   [0.2117, 0.6606, 0.1654, 0.4250],
                   [0.2117, 0.6606, 0.1654, 0.4250]]]),
        ([3, 5, 1], [[[0.3990, 0.5167, 0.0249, 0.9401],
                      [0.9459, 0.7967, 0.4150, 0.8203],
                      [0.2290, 0.9096, 0.1183, 0.0752],
                      [0.4092, 0.9601, 0.2093, 0.1940],
                      [0.8909, 0.4387, 0.3570, 0.5454]],
                     [[0.4092, 0.9601, 0.2093, 0.1940],
                      [0.8909, 0.4387, 0.3570, 0.5454],
                      [0.8299, 0.2099, 0.7684, 0.4290],
                      [0.2117, 0.6606, 0.1654, 0.4250],
                      [0.9927, 0.6964, 0.2472, 0.7028]],
                     [[0.7494, 0.9303, 0.0494, 0.0750],
                      [0.7494, 0.9303, 0.0494, 0.0750],
                      [0.7494, 0.9303, 0.0494, 0.0750],
                      [0.7494, 0.9303, 0.0494, 0.0750],
                      [0.7494, 0.9303, 0.0494, 0.0750]]]),
    ])
    def test_unpack_gen_data(self,
                             num_gen_tokens=(3, ),
                             ref_res=(((0.3990, 0.5167, 0.0249, 0.9401),
                                       (0.9459, 0.7967, 0.4150, 0.8203),
                                       (0.2290, 0.9096, 0.1183, 0.0752)), )):
        """Check `_unpack_gen_data` against precomputed reference outputs.

        Args:
            num_gen_tokens: number of generated tokens for each sequence.
                The sequences' logits rows are packed back-to-back into one
                (sum(num_gen_tokens), V) tensor.
            ref_res: expected value of the network output 'res'. In the
                reference data it has shape
                (len(num_gen_tokens), max(num_gen_tokens), V).

        The defaults mirror the first parameterized case. They are tuples
        rather than lists, so no mutable default is shared between calls.
        The reference values depend on ``torch.manual_seed(0)`` followed by
        ``torch.rand`` on the CUDA device.
        """
        # test data
        V = 4  # width of each logits row (presumably the vocab size)
        nb = 3  # presumably the number of beams -- TODO confirm
        bl = 4  # presumably the beam length -- TODO confirm

        old_device = torch.get_default_device()
        # Restore the previous default device even when an assertion below
        # fails. Otherwise a failing case leaks "cuda" into later tests.
        self.addCleanup(torch.set_default_device, old_device)
        torch.set_default_device("cuda")
        torch.manual_seed(0)

        num_gen_tokens = torch.tensor(num_gen_tokens, dtype=torch.int32)
        ref_res = torch.tensor(ref_res, dtype=torch.float32)
        # Every sequence (not just one of them) must respect the bound.
        max_tokens_per_seq = nb * (bl - 1) + 1
        self.assertTrue(
            torch.all(num_gen_tokens <= max_tokens_per_seq).item(),
            "num_gen_tokens exceeds nb * (bl - 1) + 1")

        total_tokens = num_gen_tokens.sum()
        max_gen_token = num_gen_tokens.max().cpu()
        lm_logits = torch.rand((total_tokens, V), dtype=torch.float32)

        # Row b holds flat indices into lm_logits for sequence b. The start
        # offset is the exclusive prefix sum of num_gen_tokens. Each row is
        # padded to max_gen_token entries and clamped to the last valid row,
        # so padding slots repeat in-range rows (see the reference data).
        gen_unpack_indxs = torch.arange(max_gen_token, dtype=torch.int32)
        gen_unpack_indxs = gen_unpack_indxs.unsqueeze(0) + (
            torch.cumsum(num_gen_tokens, dim=0) - num_gen_tokens).unsqueeze(1)
        gen_unpack_indxs = torch.minimum(gen_unpack_indxs, total_tokens - 1)

        # Network input name -> torch value. The same mapping declares the
        # TRT inputs, sizes the optimization profile and feeds the session,
        # so names/shapes/dtypes cannot drift apart.
        inputs = {
            'l': lm_logits,
            'ng': num_gen_tokens,
            'gui': gen_unpack_indxs,
            'mgt': max_gen_token,
        }

        # construct trt network
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        with tensorrt_llm.net_guard(network):
            trt_inputs = {
                name:
                Tensor(name=name,
                       shape=value.shape,
                       dtype=tensorrt_llm.torch_dtype_to_trt(value.dtype))
                for name, value in inputs.items()
            }
            res_t = tensorrt_llm.models.redrafter.redrafter_helper._unpack_gen_data(
                trt_inputs['l'], trt_inputs['ng'], trt_inputs['gui'],
                trt_inputs['mgt'])
            res_t.mark_output('res')

        # needs profile for dynamic shape
        profile = builder.trt_builder.create_optimization_profile()
        for name, value in inputs.items():
            set_input_shape(profile, trt_inputs[name], value.shape, value)

        # trt run
        session = create_session(builder,
                                 network,
                                 precision='float32',
                                 optimization_profiles=[profile])
        outputs = run_session(session, inputs)
        torch.testing.assert_close(outputs['res'], ref_res, atol=0.01, rtol=0.1)
|