# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest

import torch
from parameterized import parameterized

import tensorrt_llm
import tensorrt_llm.models.redrafter
import tensorrt_llm.models.redrafter.redrafter_helper
from tensorrt_llm import Tensor

sys.path.append(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
from utils.util import create_session, run_session
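
# Expected outputs for the parameterized test_batch_index_select cases below;
# the inputs are generated deterministically via torch.manual_seed(7).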
REFS_0 = torch.tensor([3, 2, 1, 4], dtype=torch.int32, device="cuda")

REFS_1 = torch.tensor([[3, 2, 4, 1], [1, 8, 1, 3], [1, 7, 6, 4], [7, 8, 8, 4]],
                      dtype=torch.int32,
                      device="cuda")

REFS_2 = torch.tensor(
    [[[5, 4, 3, 7], [7, 7, 9, 6], [7, 8, 8, 4], [0, 2, 2, 2]],
     [[1, 5, 5, 0], [5, 7, 7, 5], [9, 4, 7, 4], [1, 0, 0, 8]],
     [[4, 8, 0, 8], [3, 4, 0, 2], [0, 9, 1, 3], [5, 6, 5, 2]],
     [[7, 6, 7, 5], [9, 7, 8, 1], [6, 8, 9, 0], [6, 1, 1, 2]]],
    dtype=torch.int32,
    device="cuda")


class TestReDrafter(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level('warning')

    ########################################################################################################################

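    # Builds a small TRT network around redrafter_helper._batch_index_select and
    # checks its output against the precomputed REFS_* tensors for 2D/3D/4D inputs.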
    @parameterized.expand([
        ((4, 4), REFS_0),
        ((4, 4, 4), REFS_1),
        ((4, 4, 4, 4), REFS_2),
    ])
    def test_batch_index_select(self, shape, ref_res) -> None:
        old_device = torch.get_default_device()
        torch.set_default_device("cuda")
        torch.manual_seed(7)
        x_data = torch.randint(10, size=shape, dtype=torch.int32)
        indices = torch.randint(shape[1], size=(shape[0], ), dtype=torch.int32)

        # construct trt network
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        with tensorrt_llm.net_guard(network):
            x_trt = Tensor(name="x",
                           shape=x_data.shape,
                           dtype=tensorrt_llm.str_dtype_to_trt("int32"))
            indices_trt = Tensor(name="indices",
                                 shape=indices.shape,
                                 dtype=tensorrt_llm.str_dtype_to_trt("int32"))

            output = tensorrt_llm.models.redrafter.redrafter_helper._batch_index_select(
                x_trt, indices_trt)
            output.mark_output("output")

        # trt run
        session = create_session(builder, network, precision='float32')
        inputs = {
            "x": x_data,
            "indices": indices,
        }
        outputs = run_session(session, inputs)

        # compare diff
        torch.testing.assert_close(ref_res, outputs["output"])
        torch.set_default_device(old_device)
        return

    ########################################################################################################################

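    # Exercises redrafter_helper._prepare_drafter_input with seeded random log-probs
    # and compares the returned probabilities against precomputed reference values.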
    def test_prepare_next_input(self) -> None:
        old_device = torch.get_default_device()
        torch.set_default_device("cuda")
        torch.manual_seed(17)
        # test data
        batch_size, num_candidates, candidate_len, vocab_size, hidden_size = 2, 4, 4, 1, 1
        draft_log_probs = torch.rand(
            [batch_size, num_candidates, candidate_len, vocab_size],
            dtype=torch.float32)
        base_log_probs = torch.rand(
            [batch_size, num_candidates, candidate_len, vocab_size],
            dtype=torch.float32)
        last_base_log_probs = torch.rand(
            [batch_size, num_candidates, vocab_size], dtype=torch.float32)
        beam_index = torch.randint(0,
                                   num_candidates, (batch_size, ),
                                   dtype=torch.int32)
        num_accept_tokens = torch.randint(0,
                                          candidate_len, (batch_size, ),
                                          dtype=torch.int32)

        # construct trt network
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        with tensorrt_llm.net_guard(network):
            draft_log_probs_trt = Tensor(
                name="draft_log_probs",
                shape=draft_log_probs.shape,
                dtype=tensorrt_llm.str_dtype_to_trt("float32"),
            )
            base_log_probs_trt = Tensor(
                name="base_log_probs",
                shape=base_log_probs.shape,
                dtype=tensorrt_llm.str_dtype_to_trt("float32"),
            )
            last_base_log_probs_trt = Tensor(
                name="last_base_log_probs",
                shape=last_base_log_probs.shape,
                dtype=tensorrt_llm.str_dtype_to_trt("float32"),
            )
            beam_index_trt = Tensor(
                name="beam_index",
                shape=beam_index.shape,
                dtype=tensorrt_llm.str_dtype_to_trt("int32"),
            )
            num_accept_tokens_trt = Tensor(
                name="num_accept_tokens",
                shape=num_accept_tokens.shape,
                dtype=tensorrt_llm.str_dtype_to_trt("int32"),
            )
            probs = tensorrt_llm.models.redrafter.redrafter_helper._prepare_drafter_input(
                draft_log_probs_trt,
                base_log_probs_trt,
                last_base_log_probs_trt,
                beam_index_trt,
                num_accept_tokens_trt,
            )
            probs.mark_output("probs")

        # trt run
        session = create_session(builder, network, precision='float32')
        inputs = {
            "draft_log_probs": draft_log_probs,
            "base_log_probs": base_log_probs,
            "last_base_log_probs": last_base_log_probs,
            "beam_index": beam_index,
            "num_accept_tokens": num_accept_tokens,
        }
        outputs = run_session(session, inputs)

        ref_probs = torch.tensor([[0.1245], [0.3713]], dtype=torch.float32)

        # compare diff
        torch.testing.assert_close(ref_probs,
                                   outputs["probs"],
                                   atol=1e-4,
                                   rtol=0.1)
        torch.set_default_device(old_device)
        return