TensorRT-LLMs/tests/model/redrafter/test_unpack_gen_data.py

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest

import torch
from parameterized import parameterized

import tensorrt_llm
import tensorrt_llm.models.redrafter
import tensorrt_llm.models.redrafter.redrafter_helper
from tensorrt_llm import Tensor

sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
from utils.util import create_session, run_session, set_input_shape


class TestReDrafter(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level('warning')

    ########################################################################################################################
    @parameterized.expand([
        ([3], [[[0.3990, 0.5167, 0.0249, 0.9401],
                [0.9459, 0.7967, 0.4150, 0.8203],
                [0.2290, 0.9096, 0.1183, 0.0752]]]),
        ([3, 4], [[[0.3990, 0.5167, 0.0249, 0.9401],
                   [0.9459, 0.7967, 0.4150, 0.8203],
                   [0.2290, 0.9096, 0.1183, 0.0752],
                   [0.4092, 0.9601, 0.2093, 0.1940]],
                  [[0.4092, 0.9601, 0.2093, 0.1940],
                   [0.8909, 0.4387, 0.3570, 0.5454],
                   [0.8299, 0.2099, 0.7684, 0.4290],
                   [0.2117, 0.6606, 0.1654, 0.4250]]]),
        ([4, 3], [[[0.3990, 0.5167, 0.0249, 0.9401],
                   [0.9459, 0.7967, 0.4150, 0.8203],
                   [0.2290, 0.9096, 0.1183, 0.0752],
                   [0.4092, 0.9601, 0.2093, 0.1940]],
                  [[0.8909, 0.4387, 0.3570, 0.5454],
                   [0.8299, 0.2099, 0.7684, 0.4290],
                   [0.2117, 0.6606, 0.1654, 0.4250],
                   [0.2117, 0.6606, 0.1654, 0.4250]]]),
        ([3, 5, 1], [[[0.3990, 0.5167, 0.0249, 0.9401],
                      [0.9459, 0.7967, 0.4150, 0.8203],
                      [0.2290, 0.9096, 0.1183, 0.0752],
                      [0.4092, 0.9601, 0.2093, 0.1940],
                      [0.8909, 0.4387, 0.3570, 0.5454]],
                     [[0.4092, 0.9601, 0.2093, 0.1940],
                      [0.8909, 0.4387, 0.3570, 0.5454],
                      [0.8299, 0.2099, 0.7684, 0.4290],
                      [0.2117, 0.6606, 0.1654, 0.4250],
                      [0.9927, 0.6964, 0.2472, 0.7028]],
                     [[0.7494, 0.9303, 0.0494, 0.0750],
                      [0.7494, 0.9303, 0.0494, 0.0750],
                      [0.7494, 0.9303, 0.0494, 0.0750],
                      [0.7494, 0.9303, 0.0494, 0.0750],
                      [0.7494, 0.9303, 0.0494, 0.0750]]]),
    ])
    def test_unpack_gen_data(self,
                             num_gen_tokens=[3],
                             ref_res=[[[0.3990, 0.5167, 0.0249, 0.9401],
                                       [0.9459, 0.7967, 0.4150, 0.8203],
                                       [0.2290, 0.9096, 0.1183, 0.0752]]]):
        # test data
        V = 4  # vocab size (last dim of the logits)
        nb = 3
        bl = 4
        old_device = torch.get_default_device()
        torch.set_default_device("cuda")
        torch.manual_seed(0)
        num_gen_tokens = torch.tensor(num_gen_tokens, dtype=torch.int32)
        ref_res = torch.tensor(ref_res, dtype=torch.float32)
        # each sequence may generate at most nb * (bl - 1) + 1 tokens
        assert torch.any(num_gen_tokens <= (nb * (bl - 1) + 1))
        total_tokens = num_gen_tokens.sum()
        max_gen_token = num_gen_tokens.max().cpu()
        # packed logits for all sequences: (total_tokens, V)
        lm_logits = torch.rand((total_tokens, V), dtype=torch.float32)
        # per-sequence unpack indices: exclusive cumsum of num_gen_tokens plus the
        # position offset, clamped to the last packed row
        gen_unpack_indxs = torch.arange(max_gen_token, dtype=torch.int32)
        gen_unpack_indxs = gen_unpack_indxs.unsqueeze(0) + (
            torch.cumsum(num_gen_tokens, dim=0) - num_gen_tokens).unsqueeze(1)
        gen_unpack_indxs = torch.minimum(gen_unpack_indxs, total_tokens - 1)
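        # Sanity sketch (added; an assumption about the op's semantics, not part of
        # the original test): unpacking the packed logits into a
        # (batch_size, max_gen_token, V) tensor should amount to a row gather with
        # the indices computed above, where positions past a sequence's own tokens
        # simply repeat the clamped last row. It is not asserted against, since the
        # ground truth for this test is the hard-coded ref_res.
        _gathered_ref = lm_logits[gen_unpack_indxs.to(torch.int64)]  # noqa: F841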
        # construct trt network
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        with tensorrt_llm.net_guard(network):
            lm_logits_t = Tensor(name='l',
                                 shape=lm_logits.shape,
                                 dtype=tensorrt_llm.torch_dtype_to_trt(
                                     lm_logits.dtype))
            num_gen_tokens_t = Tensor(name='ng',
                                      shape=num_gen_tokens.shape,
                                      dtype=tensorrt_llm.torch_dtype_to_trt(
                                          num_gen_tokens.dtype))
            gen_unpack_indxs_t = Tensor(name='gui',
                                        shape=gen_unpack_indxs.shape,
                                        dtype=tensorrt_llm.torch_dtype_to_trt(
                                            gen_unpack_indxs.dtype))
            max_gen_token_t = Tensor(name='mgt',
                                     shape=max_gen_token.shape,
                                     dtype=tensorrt_llm.torch_dtype_to_trt(
                                         max_gen_token.dtype))
            outputs = tensorrt_llm.models.redrafter.redrafter_helper._unpack_gen_data(
                lm_logits_t, num_gen_tokens_t, gen_unpack_indxs_t,
                max_gen_token_t)
            outputs.mark_output('res')

        # save onnx
        # model_path = 'unpack_gen.onnx'
        # to_onnx(net.trt_network, model_path)

        # needs profile for dynamic shape
        profile = builder.trt_builder.create_optimization_profile()
        set_input_shape(profile, lm_logits_t, lm_logits.shape, lm_logits)
        set_input_shape(profile, num_gen_tokens_t, num_gen_tokens.shape,
                        num_gen_tokens)
        set_input_shape(profile, gen_unpack_indxs_t, gen_unpack_indxs.shape,
                        gen_unpack_indxs)
        set_input_shape(profile, max_gen_token_t, max_gen_token.shape,
                        max_gen_token)

        # trt run
        session = create_session(builder,
                                 network,
                                 precision='float32',
                                 optimization_profiles=[profile])
        inputs = {
            'l': lm_logits,
            'ng': num_gen_tokens,
            'gui': gen_unpack_indxs,
            'mgt': max_gen_token,
        }
        outputs = run_session(session, inputs)
        # print(outputs)
        torch.testing.assert_close(outputs['res'], ref_res, atol=0.01, rtol=0.1)
        torch.set_default_device(old_device)
        return
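

if __name__ == "__main__":
    # Convenience entry point so this file can be run directly (a standard
    # unittest pattern; assumed here rather than taken from the original listing).
    unittest.main()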