# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import unittest
from itertools import product

import numpy as np
import torch
from parameterized import parameterized

import tensorrt_llm
from tensorrt_llm import Tensor

# Make the shared test utilities in the parent directory importable.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.util import create_session, run_session


class TestFunctional(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level('error')

    # TODO(bhsueh): rank 1 is not supported yet.
    @parameterized.expand(
        list(
            product([[1], [2], [4], [2, 4], [8], [16], [8, 16], [1, 2, 4],
                     [1, 2, 4, 8, 16]])))
    def test_ranks(self, lora_ranks_list):
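        """Compare the LoRA plugin output against a plain PyTorch reference.

        Each entry of lora_ranks_list is the LoRA rank of one request in
        the batch, so the parameterization also covers batches whose
        requests use different adapter ranks.
        """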
        print(f"[INFO] test lora_ranks_list: {lora_ranks_list}")
        # Pin the LoRA plugin to a single GEMM code path for this test.
        os.environ['LORA_USE_UNIFIED_GEMM'] = 'OFF'
        # A fixed seed keeps the random inputs and weights reproducible.
        torch.random.manual_seed(0)

        dtype = 'float16'
        torch_dtype = torch.float16
        device = 'cuda'

        batch_size = len(lora_ranks_list)
        input_length = 32
        hidden_size = 4096
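
        # LoRA computes y = (x @ A) @ B per request, with "in" weight
        # A: [hidden_size, rank] and "out" weight B: [rank, hidden_size];
        # build one input sequence and one (A, B) pair per request.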
        input_data = [
            torch.randn(input_length, hidden_size,
                        device=device).to(torch_dtype) * 0.1
            for _ in range(batch_size)
        ]
        lora_weight_ins = [
            torch.randn(hidden_size, lora_rank, device=device).to(torch_dtype)
            * 0.1 for lora_rank in lora_ranks_list
        ]
        lora_weight_outs = [
            torch.randn(lora_rank, hidden_size, device=device).to(torch_dtype)
            * 0.1 for lora_rank in lora_ranks_list
        ]
        host_context_lengths = torch.Tensor(
            [input_length for _ in range(batch_size)]).to(torch.int32)
        lora_ranks = torch.Tensor(lora_ranks_list).to(torch.int32)
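
        # Ground-truth result with plain PyTorch matmuls, computed before
        # the weights are rearranged into the layout the plugin expects.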
        ref_data = [
            torch.matmul(torch.matmul(input, in_weight), out_weight)
            for input, in_weight, out_weight in zip(input_data,
                                                    lora_weight_ins,
                                                    lora_weight_outs)
        ]
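
        # The plugin reads each weight in transposed layout, so hand it
        # contiguous transposed copies.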
        lora_weight_ins = [
            tmp.transpose(1, 0).contiguous() for tmp in lora_weight_ins
        ]
        lora_weight_outs = [
            tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs
        ]
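
        # The plugin receives the weights as raw device pointers: a
        # [batch_size, 2] int64 tensor holding (in_ptr, out_ptr) per request.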
        lora_weights_pointers = []
        for in_ptr, out_ptr in zip(lora_weight_ins, lora_weight_outs):
            lora_weights_pointers.append(in_ptr.data_ptr())
            lora_weights_pointers.append(out_ptr.data_ptr())

        lora_weights_pointers = torch.LongTensor(lora_weights_pointers).to(
            torch.int64).reshape([batch_size, 2])
        # Request type 0 marks every request as a context-phase request.
        host_request_types = torch.zeros_like(host_context_lengths,
                                              device='cpu').int()
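
        # The plugin consumes all requests flattened into a single
        # [batch_size * input_length, hidden_size] tensor.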
        concat_input_data = torch.concat(input_data).contiguous().to(device)

        # Construct a TRT network that contains only the LoRA plugin.
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        network.plugin_config.set_lora_plugin(dtype)
        with tensorrt_llm.net_guard(network):
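            # Declare the network inputs that mirror the runtime buffers
            # prepared above; shapes are static for this unit test.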
            input_tensor = Tensor(name='input_tensor',
                                  shape=concat_input_data.shape,
                                  dtype=tensorrt_llm.str_dtype_to_trt(dtype))
            host_request_types_tensor = Tensor(
                name='host_request_types',
                shape=[batch_size],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_context_lengths_tensor = Tensor(
                name='host_context_lengths',
                shape=[batch_size],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            lora_ranks_tensor = Tensor(
                name='lora_ranks',
                shape=[batch_size],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            lora_weights_pointers_tensor = Tensor(
                name='lora_weights_pointers',
                shape=[batch_size, 2],
                dtype=tensorrt_llm.str_dtype_to_trt('int64'))
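
            # Assemble the plugin layer: one LoRA module whose input and
            # output hidden sizes are both hidden_size; the max LoRA rank
            # handed to the plugin is clamped to at least 8.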
            output = tensorrt_llm.functional.lora_plugin(
                input_tensor,
                hidden_size,
                [hidden_size],
                host_request_types_tensor,
                False,
                True,
                host_context_lengths_tensor,
                max(max(lora_ranks_list), 8),
                [lora_ranks_tensor],
                [lora_weights_pointers_tensor],
                weight_index=0,
            )
            output.mark_output('output')

        # TRT run: build the engine and execute it with the prepared buffers.
        session = create_session(builder, network, precision=dtype)
        inputs = {
            'input_tensor': concat_input_data,
            'host_request_types': host_request_types,
            'host_context_lengths': host_context_lengths,
            'lora_ranks': lora_ranks,
            'lora_weights_pointers': lora_weights_pointers,
        }
        outputs = run_session(session, inputs)

        # PyTorch reference, concatenated to match the flattened TRT output.
        ref_data = torch.concat(ref_data)

        # Compare with a dtype-dependent absolute tolerance.
        dtype_atol = {"float16": 1e-2, "float32": 2e-3, "bfloat16": 1e-1}
        np.testing.assert_allclose(ref_data.to(torch.float32).cpu().numpy(),
                                   outputs['output'].to(
                                       torch.float32).cpu().numpy(),
                                   atol=dtype_atol[dtype])