TensorRT-LLMs/tests/model/redrafter/test_prepare_input.py
2024-08-29 17:25:07 +08:00

174 lines
6.6 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest
import torch
from parameterized import parameterized
import tensorrt_llm
import tensorrt_llm.models.redrafter
import tensorrt_llm.models.redrafter.redrafter_helper
from tensorrt_llm import Tensor
sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
from utils.util import create_session, run_session
# Expected engine outputs for test_batch_index_select, one per parameterized
# input shape. The values are only meaningful for the inputs that test draws
# after torch.manual_seed(7); changing the seed, the shapes or the order of the
# random draws invalidates them.
# NOTE: these tensors are created on "cuda" at import time, so importing this
# module requires a CUDA device.
# Reference for the x input of shape (4, 4).
REFS_0 = torch.tensor([3, 2, 1, 4], dtype=torch.int32, device="cuda")
# Reference for the x input of shape (4, 4, 4).
REFS_1 = torch.tensor([[3, 2, 4, 1], [1, 8, 1, 3], [1, 7, 6, 4], [7, 8, 8, 4]],
dtype=torch.int32,
device="cuda")
# Reference for the x input of shape (4, 4, 4, 4).
REFS_2 = torch.tensor(
[[[5, 4, 3, 7], [7, 7, 9, 6], [7, 8, 8, 4], [0, 2, 2, 2]],
[[1, 5, 5, 0], [5, 7, 7, 5], [9, 4, 7, 4], [1, 0, 0, 8]],
[[4, 8, 0, 8], [3, 4, 0, 2], [0, 9, 1, 3], [5, 6, 5, 2]],
[[7, 6, 7, 5], [9, 7, 8, 1], [6, 8, 9, 0], [6, 1, 1, 2]]],
dtype=torch.int32,
device="cuda")
class TestReDrafter(unittest.TestCase):
    """TensorRT engine tests for helper ops in ``redrafter_helper``.

    Each test wraps one helper in a small TensorRT-LLM network, runs it through
    ``create_session``/``run_session`` and compares the engine output with
    reference values that were recorded for a fixed torch seed.
    """

    def setUp(self):
        """Limit TensorRT-LLM logging to warnings for every test."""
        tensorrt_llm.logger.set_level('warning')

    def _use_cuda_default_device(self) -> None:
        """Make "cuda" the torch default device for the current test only.

        The previous default device is restored through ``addCleanup``, so it
        is put back even when the test fails or raises part-way through,
        instead of leaking the changed default into tests that run afterwards.
        """
        # The previous device is captured here, before the default is changed.
        self.addCleanup(torch.set_default_device, torch.get_default_device())
        torch.set_default_device("cuda")

    @parameterized.expand([
        ((4, 4), REFS_0),
        ((4, 4, 4), REFS_1),
        ((4, 4, 4, 4), REFS_2),
    ])
    def test_batch_index_select(self, shape, ref_res) -> None:
        """Run ``_batch_index_select`` in an engine and compare with ``ref_res``.

        Args:
            shape: Shape of the random int32 input ``x``. One index in
                ``[0, shape[1])`` is drawn per entry of the first dimension.
            ref_res: Expected engine output for this shape under seed 7.
        """
        self._use_cuda_default_device()

        # The references are tied to this seed and to the order of the two
        # random draws below; do not reorder them.
        torch.manual_seed(7)
        x_data = torch.randint(10, size=shape, dtype=torch.int32)
        indices = torch.randint(shape[1], size=(shape[0], ), dtype=torch.int32)

        # construct trt network
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        int32_trt = tensorrt_llm.str_dtype_to_trt("int32")
        with tensorrt_llm.net_guard(network):
            x_trt = Tensor(name="x", shape=x_data.shape, dtype=int32_trt)
            indices_trt = Tensor(name="indices",
                                 shape=indices.shape,
                                 dtype=int32_trt)
            output = tensorrt_llm.models.redrafter.redrafter_helper._batch_index_select(
                x_trt, indices_trt)
            output.mark_output("output")

        # trt run
        session = create_session(builder, network, precision='float32')
        outputs = run_session(session, {
            "x": x_data,
            "indices": indices,
        })

        # compare diff
        torch.testing.assert_close(ref_res, outputs["output"])

    def test_prepare_next_input(self) -> None:
        """Run ``_prepare_drafter_input`` in an engine and check its output.

        Inputs are random tensors drawn under seed 17; the "probs" output is
        compared with hard-coded reference values recorded for that seed.
        """
        self._use_cuda_default_device()

        # The reference values depend on this seed and on the order of the
        # five random draws below; do not reorder them.
        torch.manual_seed(17)

        # test data
        batch_size, num_candidates, candidate_len, vocab_size = 2, 4, 4, 1
        log_probs_shape = [
            batch_size, num_candidates, candidate_len, vocab_size
        ]
        draft_log_probs = torch.rand(log_probs_shape, dtype=torch.float32)
        base_log_probs = torch.rand(log_probs_shape, dtype=torch.float32)
        last_base_log_probs = torch.rand(
            [batch_size, num_candidates, vocab_size], dtype=torch.float32)
        beam_index = torch.randint(0,
                                   num_candidates, (batch_size, ),
                                   dtype=torch.int32)
        num_accept_tokens = torch.randint(0,
                                          candidate_len, (batch_size, ),
                                          dtype=torch.int32)

        # Engine inputs keyed by network tensor name, in the positional order
        # expected by _prepare_drafter_input.
        inputs = {
            "draft_log_probs": draft_log_probs,
            "base_log_probs": base_log_probs,
            "last_base_log_probs": last_base_log_probs,
            "beam_index": beam_index,
            "num_accept_tokens": num_accept_tokens,
        }
        trt_dtype_names = {
            "draft_log_probs": "float32",
            "base_log_probs": "float32",
            "last_base_log_probs": "float32",
            "beam_index": "int32",
            "num_accept_tokens": "int32",
        }

        # construct trt network
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        with tensorrt_llm.net_guard(network):
            trt_inputs = [
                Tensor(
                    name=name,
                    shape=data.shape,
                    dtype=tensorrt_llm.str_dtype_to_trt(trt_dtype_names[name]),
                ) for name, data in inputs.items()
            ]
            probs = tensorrt_llm.models.redrafter.redrafter_helper._prepare_drafter_input(
                *trt_inputs)
            probs.mark_output("probs")

        # trt run
        session = create_session(builder, network, precision='float32')
        outputs = run_session(session, inputs)

        # Created on the "cuda" default device set above.
        ref_probs = torch.tensor([[0.1245], [0.3713]], dtype=torch.float32)

        # compare diff
        torch.testing.assert_close(ref_probs,
                                   outputs["probs"],
                                   atol=1e-4,
                                   rtol=0.1)