Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-16 07:53:55 +08:00
Added loraOp into the LoRA layer + test for MLP and comparison to the LoRA plugin (#3455)
LoraOp integration into torch modules. Signed-off-by: Ubuntu <dafrimi@nvidia.com>
This commit is contained in:
parent 239fe0ff26
commit 0f084d9566
@@ -70,7 +70,8 @@ add_library(
  userbuffersFinalizeOp.cpp
  userbuffersTensor.cpp
  weightOnlyQuantOp.cpp
  mtpOp.cpp)
  mtpOp.cpp
  loraOp.cpp)
set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(th_common PRIVATE ${TORCH_LIBRARIES} th_utils
                      ${Python3_LIBRARIES} ${SHARED_TARGET})

cpp/tensorrt_llm/thop/loraOp.cpp (new file, 192 lines)
@@ -0,0 +1,192 @@
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/lora/lora.h"
#include "tensorrt_llm/kernels/selectiveScan/selectiveScan.h"
#include "tensorrt_llm/thop/thUtils.h"

namespace th = torch;
namespace tk = tensorrt_llm::kernels;
using tensorrt_llm::common::fmtstr;

namespace torch_ext
{

enum class RequestType : int32_t
{
    kCONTEXT = 0,
    kGENERATION = 1
};

int64_t getNumTokens(th::Tensor const& input)
{
    int ndim = input.sizes().size();
    TLLM_CHECK_WITH_INFO(
        3 == ndim || 2 == ndim, "hidden_state dimension should be either 2 [numTokens, hidden], or 3 [b, s, hidden]");
    int64_t num_tokens = input.sizes()[0];
    if (ndim == 3)
    {
        num_tokens *= input.sizes()[1];
    }
    return num_tokens;
}

std::vector<th::Tensor> lora_grouped_gemm(th::Tensor const& input, th::Tensor const& host_request_types,
    std::vector<th::Tensor> const& lora_ranks, // numModules tensors, each tensor has a single value per request
    std::vector<th::Tensor> const& lora_weights_pointers, th::Tensor const& host_context_lengths,
    std::vector<int64_t> const& output_hidden_sizes, bool transA, bool transB, int64_t const max_low_rank,
    int64_t const& weight_index, bool isRemoveInputPadding)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto stream = at::cuda::getCurrentCUDAStream().stream();
    auto const numReqs = lora_ranks[0].sizes()[0];
    auto const out_shape = input.sizes();
    int const numLoraModules = lora_ranks.size();
    TLLM_CHECK_WITH_INFO(lora_ranks.size() == lora_weights_pointers.size(), "both should be numLoraModules");

    std::vector<th::Tensor> output_torch;
    for (int i = 0; i < numLoraModules; i++)
    {
        std::vector<int64_t> output_shape = {out_shape[0], output_hidden_sizes[i]};
        if (!isRemoveInputPadding)
        {
            output_shape = {out_shape[0], out_shape[1], output_hidden_sizes[i]};
        }
        output_torch.push_back(torch::empty(output_shape, input.options()));
    }
    std::vector<void*> output;
    for (auto tensor_it = output_torch.begin(); tensor_it != output_torch.end(); tensor_it++)
    {
        output.push_back(tensor_it->data_ptr());
    }
    int const seqLen = isRemoveInputPadding ? 0 : input.sizes()[1];
    int32_t const* reqTypes = static_cast<int32_t const*>(host_request_types.data_ptr());
    int32_t const* hostContextLengths
        = isRemoveInputPadding ? static_cast<int32_t const*>(host_context_lengths.data_ptr()) : nullptr;

    int64_t numTokens = getNumTokens(input);

    std::vector<void const*> expandLoraWeightPtrs{};
    std::vector<int32_t> expandLoraRanks{};

    expandLoraWeightPtrs.reserve(numLoraModules * numTokens * 2);
    expandLoraRanks.reserve(numLoraModules * numTokens);

    for (int loraModuleIdx = 0; loraModuleIdx < numLoraModules; loraModuleIdx++)
    {
        auto const loraRankModule = static_cast<int32_t const*>(lora_ranks[loraModuleIdx].data_ptr());
        auto const loraWeightModulePtrs = static_cast<int64_t const*>(lora_weights_pointers[loraModuleIdx].data_ptr());

        int idx = 0;
        for (int reqId = 0; reqId < numReqs; reqId++)
        {
            // loraWeightModulePtrs has 3 pointers for each module: A, B, and an optional DoRA magnitude.
            // The current LoRA plugin does not apply DoRA scaling, so the magnitude is ignored.
            RequestType const reqType = static_cast<RequestType const>(reqTypes[reqId]);
            if (reqType == RequestType::kGENERATION)
            {
                expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3]));
                expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3 + 1]));
                expandLoraRanks.push_back(loraRankModule[reqId]);
                idx += 1;
            }
            else
            {
                int contextLen = (isRemoveInputPadding ? hostContextLengths[reqId] : seqLen);

                for (int contextId = 0; contextId < contextLen; contextId++)
                {
                    expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3]));
                    expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3 + 1]));
                    expandLoraRanks.push_back(loraRankModule[reqId]);
                    idx += 1;
                }
            }
        }

        // In the 1st generation phase of cross-attention QKV LoRA, cross QKV is skipped by passing an empty
        // encoder_output (a dimension of 0), so getNumTokens() returns 0. Skip the check for this case.
        if (numTokens > 0)
        {
            TLLM_CHECK_WITH_INFO(idx == numTokens,
                fmtstr("LoraParams and input dims don't match, lora tokens %d input tokens %ld", idx, numTokens));
        }
    }

    auto cublasHandle = getCublasHandle();
    auto cublasLtHandle = getCublasLtHandle();
    auto cublasWrapper
        = std::make_shared<tensorrt_llm::common::CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);

    int const inHiddenSize = input.sizes()[input.sizes().size() - 1];

    std::vector<int> outHiddenSizes(output_hidden_sizes.size());
    for (int i = 0; i < numLoraModules; i++)
    {
        outHiddenSizes[i] = output_hidden_sizes[i];
    }
    nvinfer1::DataType loraRuntimeDataType;
    switch (input.scalar_type())
    {
    case torch::kFloat16: loraRuntimeDataType = nvinfer1::DataType::kHALF; break;
    case torch::kBFloat16: loraRuntimeDataType = nvinfer1::DataType::kBF16; break;
    default: throw std::invalid_argument("Invalid dtype, only supports float16, bfloat16");
    }

    auto mLoraImpl = std::make_shared<tensorrt_llm::kernels::LoraImpl>(
        inHiddenSize, outHiddenSizes, transA, transB, numLoraModules, loraRuntimeDataType, max_low_rank, cublasWrapper);

    // TODO (dafrimi): use Profiler to find the best tactic as used in lora_plugin
    mLoraImpl->setBestTactic(std::nullopt);

    auto const workspace_size = mLoraImpl->getWorkspaceSize(numTokens, numReqs, loraRuntimeDataType);

    auto workspace = torch::empty(std::vector<int64_t>{static_cast<int64_t>(workspace_size)}, input.options());

    mLoraImpl->run(numTokens, numReqs, input.data_ptr(), expandLoraRanks.data(), expandLoraWeightPtrs.data(),
        weight_index, output.data(), workspace.data_ptr(), stream);
    sync_check_cuda_error(stream);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    return output_torch;
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "lora_grouped_gemm(Tensor input, "
        "Tensor host_request_types, "
        "Tensor[] lora_ranks, "
        "Tensor[] lora_weights_pointers, "
        "Tensor host_context_lengths, "
        "int[] output_hidden_sizes, "
        "bool transA, "
        "bool transB, "
        "int max_low_rank, "
        "int weight_index, "
        "bool isRemoveInputPadding) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("lora_grouped_gemm", &torch_ext::lora_grouped_gemm);
}
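For reference, a minimal sketch of how the registered op can be invoked from Python once the trtllm ops library is loaded (importing tensorrt_llm is typically enough). The sizes, the single LoRA module, and the all-context batch below are illustrative assumptions; the weight layout and pointer ordering mirror the TestLoraPluginVsLayer test added later in this diff, not an independent specification.

# Hedged sketch only; assumes a TensorRT-LLM build that registers torch.ops.trtllm.lora_grouped_gemm.
import torch
import tensorrt_llm  # noqa: F401  (assumed to load the trtllm torch ops)

batch, seq, hidden, rank = 2, 8, 1024, 8
x = torch.randn(batch, seq, hidden, dtype=torch.float16, device='cuda') * 0.1

# Per request, three pointers per module: "in" weight, "out" weight, DoRA magnitude (0 when unused),
# built the same way as in the new unit test.
w_in = (torch.randn(hidden, rank, device='cuda').half() * 0.1).contiguous()
w_out = (torch.randn(rank, hidden, device='cuda').half() * 0.1).transpose(1, 0).contiguous()
weight_ptrs = torch.tensor([[w_in.data_ptr(), w_out.data_ptr(), 0]] * batch, dtype=torch.int64)

outputs = torch.ops.trtllm.lora_grouped_gemm(
    x,
    torch.zeros(batch, dtype=torch.int32),            # host_request_types: 0 == context
    [torch.full((batch,), rank, dtype=torch.int32)],  # lora_ranks, one tensor per module
    [weight_ptrs],                                    # lora_weights_pointers, one tensor per module
    torch.full((batch,), seq, dtype=torch.int32),     # host_context_lengths
    [hidden],                                         # output_hidden_sizes
    False, True,                                      # transA, transB
    rank,                                             # max_low_rank
    0,                                                # weight_index
    False)                                            # isRemoveInputPadding
# outputs is a list with one tensor per module; outputs[0] has shape [batch, seq, hidden].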
@@ -95,90 +95,61 @@ class LoraLayer(torch.nn.Module):
    def forward(self, x, lora_params: Dict,
                layer_idx: int) -> Optional[torch.Tensor]:
        if bool(lora_params):

            lora_ranks = []
            lora_weight_pointers = []
            lora_weight_tensors = [
            ]  # TODO (dafrimi) needs to delete this when we use loraOps which uses ptr
            active_lora_module_ids = []
            for module_idx in self.lora_module_types:
                module_idx = int(module_idx)
                if module_idx in lora_params[layer_idx]:
                    active_lora_module_ids.append(module_idx)

                    is_dora = lora_params[layer_idx][module_idx][
                        'is_dora']  # todo (dafrimi) use it when calling loraOP

                    # TODO (dafrimi): needs to pass this is_dora arg
                    lora_params[layer_idx][module_idx]['is_dora']
                    lora_ranks.append(
                        lora_params[layer_idx][module_idx]['adapter_size'])
                    lora_weight_pointers.append(
                        lora_params[layer_idx][module_idx]['weight_pointers'])
                    lora_weight_tensors.append(
                        lora_params[layer_idx][module_idx]['weight_tensors']
                    )  # TODO (dafrimi) needs to delete this when we use loraOps which uses ptr

            lora_params['num_seqs']
            num_seqs = lora_params['num_seqs']

            if len(active_lora_module_ids) == 0:
                return None
            else:
                # If there's only one module, use the existing implementation
                if len(active_lora_module_ids) == 1:
                    lora_weight_tensors = lora_weight_tensors[0]
                    lora_output = (
                        x @ lora_weight_tensors[1].T) @ lora_weight_tensors[0].T
                lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
                    x,
                    lora_params['host_request_types'][:num_seqs],
                    lora_ranks,
                    lora_weight_pointers,
                    lora_params['prompt_lens_cpu'][:num_seqs],
                    self.output_hidden_sizes,
                    False,  # transA
                    True,  # transB
                    max([r.max() for r in lora_ranks]),
                    0,
                    lora_params["remove_input_padding"],
                )
                if isinstance(lora_outputs, torch.Tensor):
                    return lora_outputs
                else:
                    # For multiple LoRA modules, some might not be executed in the grouped gemm.
                    # For those modules not executed, we create zero tensors with matching dimensions.
                    # Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order.
                    lora_output = []
                    for module_idx in self.lora_module_types:
                        if int(module_idx) in active_lora_module_ids:
                            lora_output.append(lora_outputs.pop(0))
                        else:
                            lora_output.append(
                                torch.zeros(list(x.shape[:-1]) + [
                                    self.output_hidden_sizes[
                                        self.lora_module_types.index(
                                            module_idx)]
                                ],
                                            dtype=x.dtype,
                                            device=x.device))

                    lora_output = torch.cat(lora_output, dim=-1)
                    return lora_output

                # For multiple modules, compute and concatenate outputs
                lora_outputs = []
                for module_idx in self.lora_module_types:
                    module_idx = int(module_idx)
                    if module_idx in active_lora_module_ids:
                        i = active_lora_module_ids.index(module_idx)
                        weight_tensors = lora_weight_tensors[i]
                        A, B = weight_tensors[0], weight_tensors[1]
                        lora_output = (x @ B.T) @ A.T
                        lora_outputs.append(lora_output)

                # Concatenate outputs from all modules
                lora_output = torch.cat(lora_outputs, dim=-1)
                return lora_output

                # TODO(dafrimi): use torch implementation. For now, this is just a placeholder, until we do the binding to the lora ops C++
                # lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
                #     x,
                #     lora_params['host_request_types'][:num_seqs],
                #     lora_ranks,
                #     lora_weight_pointers,
                #     lora_params['prompt_lens_cpu'][:num_seqs],
                #     self.output_hidden_sizes,
                #     False,  # transA
                #     True,  # transB
                #     max([r.max() for r in lora_ranks]),
                #     0,
                #     True,
                # )
                # if isinstance(lora_outputs, torch.Tensor):
                #     return lora_outputs
                # else:
                #     # For multiple LoRA modules, some might not be executed in grouped gemm.
                #     # For those modules not executed, we create zero tensors with matching dimensions.
                #     # Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order.
                #     lora_output = []
                #     for module_idx in self.lora_module_types:
                #         if int(module_idx) in active_lora_module_ids:
                #             lora_output.append(lora_outputs.pop(0))
                #         else:
                #             lora_output.append(
                #                 torch.zeros(list(x.shape[:-1]) + [
                #                     self.output_hidden_sizes[
                #                         self.lora_module_types.index(
                #                             module_idx)]
                #                 ],
                #                             dtype=x.dtype,
                #                             device=x.device))

                #     lora_output = torch.cat(lora_output, dim=-1)
                #     return lora_output

        else:
            return None
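For orientation, here is a minimal sketch of the lora_params dictionary this forward expects, assembled the same way the attention test later in this diff builds it. The single ATTENTION_Q entry, the sizes, and the variable names are illustrative assumptions, and the weight/pointer ordering simply mirrors those tests rather than an independent specification.

# Hedged sketch only; keys mirror the tests in this PR, sizes are assumptions.
import torch
from tensorrt_llm._torch.peft.lora.layer import LoraModuleType

batch_size, seq_len, hidden_size, rank = 1, 16, 64, 8
w_in = (torch.randn(hidden_size, rank, device='cuda').half() * 0.1).contiguous()
w_out = (torch.randn(rank, hidden_size, device='cuda').half() * 0.1).transpose(
    1, 0).contiguous()

lora_params = {
    'num_seqs': batch_size,
    'host_request_types': torch.zeros(batch_size, dtype=torch.int32),  # 0 == context
    'prompt_lens_cpu': torch.tensor([seq_len] * batch_size, dtype=torch.int32),
    'remove_input_padding': True,
    0: {  # layer_idx
        LoraModuleType.ATTENTION_Q: {  # module type
            'adapter_size': torch.tensor([rank], dtype=torch.int32),
            # per request: [in ptr, out ptr, DoRA magnitude ptr (0 when unused)]
            'weight_pointers': torch.tensor(
                [[w_in.data_ptr(), w_out.data_ptr(), 0]], dtype=torch.int64),
            'is_dora': False,
            'weight_tensors': [w_out, w_in],  # only used by the legacy torch path
        },
    },
}
# The layer is then called as lora_layer(x, lora_params, layer_idx), with x shaped
# [num_tokens, hidden_size] when remove_input_padding is True.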
@@ -12,11 +12,13 @@ from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_llama import LlamaAttention
# LoRA imports
from tensorrt_llm._torch.peft.lora.layer import LoraModuleType
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.layers import Attention
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.layers import AttentionParams, KeyValueCacheParams
from tensorrt_llm.layers.lora import Lora, LoraParams
from tensorrt_llm.mapping import Mapping
@@ -27,94 +29,34 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase):
    def setUpClass(cls):
        cls.batch_size = 1
        cls.seq_len = 16
        cls.hidden_size = 64
        cls.head_num = 1
        cls.num_hidden_layers = 1
        cls.head_size = 64
        cls.hidden_size = cls.head_num * cls.head_size
        cls.dtype = 'float16'
        cls.torch_dtype = str_dtype_to_torch(cls.dtype)
        cls.device = torch.device('cuda')
        cls.pos_emb_type = PositionEmbeddingType.learned_absolute
        cls.causal_mask = True

        # KV cache parameters
        cls.num_blocks = 4
        cls.tokens_per_block = 32

    def _create_lora_params(self, ):
        lora_ranks_list = [8 for _ in range(self.batch_size)]

        cls.llama_config = LlamaConfig(hidden_size=cls.hidden_size,
                                       num_attention_heads=cls.head_num,
                                       num_hidden_layers=cls.num_hidden_layers,
                                       intermediate_size=256,
                                       max_position_embeddings=512,
                                       rms_norm_eps=1e-5,
                                       vocab_size=32000,
                                       num_key_value_heads=cls.head_num,
                                       torch_dtype=cls.torch_dtype)

        # Create KV cache manager
        head_dim = cls.llama_config.hidden_size // cls.llama_config.num_attention_heads
        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(max_tokens=cls.num_blocks *
                                        cls.tokens_per_block)
        cls.kv_cache_manager = KVCacheManager(
            kv_cache_config=kv_cache_config,
            kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
            CacheType.SELF,
            num_layers=cls.llama_config.num_hidden_layers,
            num_kv_heads=cls.llama_config.num_key_value_heads,
            head_dim=head_dim,
            tokens_per_block=cls.tokens_per_block,
            max_seq_len=cls.num_blocks * cls.tokens_per_block,
            max_batch_size=cls.batch_size,
            mapping=mapping,
            dtype=tensorrt_llm.bindings.DataType.HALF)

    @classmethod
    def tearDownClass(cls):
        cls.kv_cache_manager.shutdown()

    def _create_attention_inputs(self):
        hidden_states = torch.empty(
            size=[self.batch_size, self.seq_len, self.hidden_size],
            dtype=self.torch_dtype,
            device='cuda')
        hidden_states.normal_(0.0, 0.02)

        # Create weights
        q_weight = torch.empty(size=[self.hidden_size, self.hidden_size],
                               dtype=self.torch_dtype)
        torch.nn.init.xavier_uniform_(q_weight)

        # Set K and V and O weights to identity matrix
        eye_weight = torch.eye(self.hidden_size, dtype=self.torch_dtype)
        qkv_weight = torch.cat([q_weight, eye_weight, eye_weight], dim=-1)
        out_weight = eye_weight

        return hidden_states, qkv_weight, out_weight

    def _create_lora_params(self):
        lora_ranks_list = [8]

        host_context_lengths = torch.Tensor(
            [self.seq_len for _ in range(self.batch_size)]).to(torch.int32)
        lora_ranks = torch.Tensor(lora_ranks_list).to(torch.int32)
        host_request_types = torch.zeros_like(host_context_lengths,
                                              device='cpu').int()

        lora_weight_ins = [
            torch.randn(self.hidden_size, lora_rank, device=self.device).to(
            torch.randn(self.hidden_size, lora_rank, device="cuda").to(
                self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list
        ]
        lora_weight_outs = [
            torch.randn(lora_rank, self.hidden_size, device=self.device).to(
            torch.randn(lora_rank, self.hidden_size, device="cuda").to(
                self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list
        ]

        lora_weight_ins = [
            tmp.transpose(1, 0).contiguous() for tmp in lora_weight_ins
        ]
        lora_weight_ins = [tmp.contiguous() for tmp in lora_weight_ins]
        lora_weight_outs = [
            tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs
        ]

        # Create weight pointers for TensorRT
        lora_weights_pointers = []
        for in_ptr, out_ptr in zip(lora_weight_ins, lora_weight_outs):
            lora_weights_pointers.append(in_ptr.data_ptr())
@@ -125,88 +67,164 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase):

        return {
            'lora_ranks': lora_ranks,
            'host_context_lengths': host_context_lengths,
            'host_request_types': host_request_types,
            'lora_weights_pointers': lora_weights_pointers,
            'lora_weight_ins': lora_weight_ins,
            'lora_weight_outs': lora_weight_outs
        }

    def _setup_attention_module(self, qkv_weight, out_weight):
        """Set up the attention module with weights."""
        model_config = ModelConfig(pretrained_config=self.llama_config,
                                   attn_backend="VANILLA")
        layer_idx = 0
        attention_module = LlamaAttention(model_config, layer_idx=layer_idx).to(
            self.device).to(self.torch_dtype)
    def test_lora_attention(self):

        # Set weights
        attention_module.qkv_proj.weight.data = torch.from_numpy(
            np.ascontiguousarray(qkv_weight.cpu().numpy().transpose(
                [1, 0]))).to(self.device)
        attention_module.o_proj.weight.data = torch.from_numpy(
            np.ascontiguousarray(out_weight.cpu().numpy().transpose(
                [1, 0]))).to(self.device)
        mean = 0.0
        std_dev = 0.02 if self.dtype == "float32" else 0.005

        return attention_module, model_config
        hidden_states = torch.concat([
            torch.empty(size=[self.seq_len, self.hidden_size],
                        dtype=self.torch_dtype,
                        device=self.device).normal_(mean, std_dev)
            for _ in range(self.batch_size)
        ])

    def _create_attention_metadata(self, model_config):
        sequence_lengths = [self.seq_len]
        past_seen_tokens = [0]
        request_ids = [0]
        token_nums = [self.seq_len]
        prompt_lens = token_nums
        context_lengths = torch.full([self.batch_size],
                                     self.seq_len,
                                     dtype=torch.int32,
                                     device=self.device)

        self.kv_cache_manager.add_dummy_requests(request_ids, token_nums)
        # Plugin specific setup - only generate 1 step
        max_seq_len = self.seq_len + 1

        metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
        return metadata_cls(
            seq_lens=torch.tensor(sequence_lengths, dtype=torch.int32),
            num_contexts=len(sequence_lengths),
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            kv_cache_manager=self.kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=prompt_lens,
            max_num_requests=self.batch_size,
            max_num_tokens=self.batch_size * self.seq_len,
        )
        # zero means "valid" token, one means invalid.
        host_past_key_value_lengths = torch.tensor([0] * self.batch_size,
                                                   dtype=torch.int32)

        # the max kv cache length for each layer. single tensor since we only have 1 layer here.
        host_max_attention_window_sizes = torch.tensor([max_seq_len],
                                                       dtype=torch.int32)
        host_sink_token_length = torch.tensor([0], dtype=torch.int32)

        sequence_length = torch.full([self.batch_size],
                                     self.seq_len,
                                     dtype=torch.int32,
                                     device=self.device)

        # even in the context phase, kv cache tensors can not be empty tensors for the plugin,
        # otherwise there will be a cublas execution error. The actual shape info
        # is passed to the plugin by the `sequence_length` tensor.
        kv_shape = (self.batch_size, 2, self.head_num, max_seq_len,
                    self.head_size)
        past_key_value = torch.randn(kv_shape,
                                     dtype=self.torch_dtype,
                                     device=self.device)
        cache_indirection = torch.full((
            self.batch_size,
            1,
            max_seq_len,
        ),
                                       0,
                                       dtype=torch.int32,
                                       device=self.device)

        host_request_types = torch.tensor([0] * self.batch_size,
                                          dtype=torch.int32,
                                          device='cpu')

        perf_knob_tensor_size = 16
        host_runtime_perf_knobs_tensor = torch.tensor([-1] *
                                                      perf_knob_tensor_size,
                                                      dtype=torch.int64,
                                                      device='cpu')
        host_context_progress = torch.tensor([0],
                                             dtype=torch.int64,
                                             device='cpu')

        host_context_lengths = torch.Tensor(
            [self.seq_len for _ in range(self.batch_size)]).to(torch.int32)

        q_weight = torch.empty(size=[self.hidden_size, self.hidden_size],
                               dtype=self.torch_dtype)
        torch.nn.init.xavier_uniform_(q_weight)

        # The initialization here is chosen to minimize computation after the
        # QKV BMMs in order to reduce the amount of differences from FP accumulation.
        # We set K and V weights to the identity matrix so that the input is copied
        # without doing any accumulation. Additionally, we set the output projection
        # to the identity for the same reason.
        # The main purpose of these tests is to check the QK^T BMM + Softmax + SV BMM for LoRA.
        eye_weight = torch.eye(self.hidden_size, dtype=self.torch_dtype)
        qkv_weight = torch.cat([q_weight, eye_weight, eye_weight], dim=-1)

        out_weight = eye_weight

        lora_params = self._create_lora_params()

    def _setup_trt_network(self, hidden_states, lora_params, attention_module):
        builder = tensorrt_llm.Builder()
        net = builder.create_network()
        net.plugin_config.to_legacy_setting()
        net.plugin_config.lora_plugin = self.dtype

        net.plugin_config.gpt_attention_plugin = self.dtype  # for ragged input we use this plugin with remove_input_padding
        net.plugin_config.remove_input_padding = True
        net.plugin_config.lora_plugin = "float16"
        with tensorrt_llm.net_guard(net):
            # Create input tensor
            trt_hidden_states = Tensor(name='hidden_states',
                                       shape=hidden_states.shape,
                                       dtype=tensorrt_llm.str_dtype_to_trt(
                                           self.dtype))

            # Create LoRA tensors
            host_request_types_tensor = Tensor(
                name='host_request_types',
                shape=[lora_params['host_request_types'].shape[0]],
            context_lengths_tensor = Tensor(
                name='context_lengths',
                shape=context_lengths.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))

            host_context_lengths_tensor = Tensor(
                name='host_context_lengths',
                shape=[lora_params['host_context_lengths'].shape[0]],
                shape=[host_context_lengths.shape[0]],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))

            lora_ranks_tensor = Tensor(
                name='lora_ranks',
                shape=[lora_params['lora_ranks'].shape[0]],
                shape=(lora_params['lora_ranks'].shape[0], ),
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))

            lora_weights_pointers_tensor = Tensor(
                name='lora_weights_pointers',
                shape=lora_params['lora_weights_pointers'].shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int64'))

            # Create LoRA parameters
            lora_params = LoraParams(
            host_request_types_tensor = Tensor(
                name='host_request_types',
                shape=host_request_types.shape,
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            past_key_value_tensor = Tensor(name='past_key_value',
                                           shape=tuple(past_key_value.shape),
                                           dtype=tensorrt_llm.str_dtype_to_trt(
                                               self.dtype))
            sequence_length_tensor = Tensor(
                name='sequence_length',
                shape=tuple(sequence_length.shape),
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_past_key_value_lengths_tensor = Tensor(
                name='host_past_key_value_lengths',
                shape=tuple(host_past_key_value_lengths.shape),
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_max_attention_window_sizes_tensor = Tensor(
                name='host_max_attention_window_sizes',
                shape=tuple(host_max_attention_window_sizes.shape),
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_sink_token_length_tensor = Tensor(
                name='host_sink_token_length',
                shape=tuple(host_sink_token_length.shape),
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            cache_indirection_tensor = Tensor(
                name='cache_indirection',
                shape=tuple(cache_indirection.shape),
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_runtime_perf_knobs = Tensor(
                name='host_runtime_perf_knobs',
                shape=[16],
                dtype=tensorrt_llm.str_dtype_to_trt('int64'))
            host_context_progress_tensor = Tensor(
                name='host_context_progress',
                shape=[1],
                dtype=tensorrt_llm.str_dtype_to_trt('int64'))

            lora_layer_params = LoraParams(
                lora_ranks=[{
                    "attn_q_lora_ranks": lora_ranks_tensor,
                    "attn_k_lora_ranks": lora_ranks_tensor,
@@ -224,15 +242,17 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase):
                    lora_weights_pointers_tensor,
                }],
                host_context_lengths=host_context_lengths_tensor,
                host_request_types=host_request_types_tensor)
                host_request_types=host_request_types_tensor,
            )

            attn_layer = Attention(
            attn_layer = tensorrt_llm.layers.Attention(
                local_layer_idx=0,
                hidden_size=hidden_states.shape[-1],
                num_attention_heads=1,
                max_position_embeddings=hidden_states.shape[1],
                hidden_size=self.hidden_size,
                num_attention_heads=self.head_num,
                max_position_embeddings=self.seq_len,
                attention_mask_type=tensorrt_llm.layers.AttentionMaskType.
                causal,
                position_embedding_type=self.pos_emb_type,
                bias=False)

            attn_layer.qkv_lora = Lora(
@@ -254,29 +274,62 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase):
                max_low_rank=8,
            )

            # Set attention layer weights
            attn_layer.qkv.weight.value = attention_module.qkv_proj.weight.data
            attn_layer.dense.weight.value = attention_module.o_proj.weight.data
            attn_layer.qkv.weight.value = np.ascontiguousarray(
                qkv_weight.cpu().numpy().transpose([1, 0]))
            attn_layer.dense.weight.value = np.ascontiguousarray(
                out_weight.cpu().numpy().transpose([1, 0]))

            output = attn_layer(hidden_states=trt_hidden_states,
                                lora_layer_params=lora_params)
            output, present_key_value = attn_layer(
                trt_hidden_states,
                use_cache=True,
                lora_layer_params=
                lora_layer_params,  # Always use cache for plugin path in this test
                kv_cache_params=KeyValueCacheParams(
                    past_key_value=[past_key_value_tensor],
                    host_past_key_value_lengths=
                    host_past_key_value_lengths_tensor,
                    host_max_attention_window_sizes=
                    host_max_attention_window_sizes_tensor,
                    host_sink_token_length=host_sink_token_length_tensor,
                    cache_indirection=cache_indirection_tensor),
                attention_params=AttentionParams(
                    sequence_length=sequence_length_tensor,
                    context_lengths=context_lengths_tensor,
                    host_request_types=host_request_types_tensor,
                    max_context_length=self.seq_len,
                    host_runtime_perf_knobs=host_runtime_perf_knobs,
                    host_context_progress=host_context_progress_tensor,
                    host_context_lengths=host_context_lengths_tensor,
                ))

            assert isinstance(output, Tensor)
            output.mark_output('output',
                               tensorrt_llm.str_dtype_to_trt(self.dtype))
            present_key_value.mark_output(
                'present_key_value', tensorrt_llm.str_dtype_to_trt(self.dtype))

        return builder, net

    def _run_trt_inference(self, builder, net, hidden_states, lora_params):
        builder_config = builder.create_builder_config(name='attention',
        builder_config = builder.create_builder_config(name='attention_plugin',
                                                       precision=self.dtype)

        engine_buffer = builder.build_engine(net, builder_config)
        session = tensorrt_llm.runtime.Session.from_serialized_engine(
            engine_buffer)

        stream = torch.cuda.current_stream().cuda_stream

        inputs = {
            'hidden_states': hidden_states,
            'host_request_types': lora_params['host_request_types'],
            'host_context_lengths': lora_params['host_context_lengths'],
            'past_key_value': past_key_value,
            'sequence_length': sequence_length,
            'host_past_key_value_lengths': host_past_key_value_lengths,
            'host_max_attention_window_sizes': host_max_attention_window_sizes,
            'host_sink_token_length': host_sink_token_length,
            'context_lengths': context_lengths,
            'host_request_types': host_request_types,
            'cache_indirection': cache_indirection,
            'host_runtime_perf_knobs': host_runtime_perf_knobs_tensor,
            'host_context_progress': host_context_progress,
            'host_context_lengths': host_context_lengths,
            'lora_ranks': lora_params['lora_ranks'],
            'lora_weights_pointers': lora_params['lora_weights_pointers'],
        }
@@ -286,108 +339,128 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase):
                torch.empty(hidden_states.shape,
                            dtype=tensorrt_llm._utils.str_dtype_to_torch(
                                self.dtype),
                            device='cuda'),
                            device=self.device),
            'present_key_value':
                past_key_value,
        }

        session.run(inputs=inputs, outputs=outputs, stream=stream)
        torch.cuda.synchronize()

        return outputs['output'].squeeze(0)
        # Pytorch flow
        llama_config = LlamaConfig(hidden_size=self.hidden_size,
                                   num_attention_heads=self.head_num,
                                   num_hidden_layers=1,
                                   intermediate_size=256,
                                   max_position_embeddings=512,
                                   rms_norm_eps=1e-5,
                                   vocab_size=32000,
                                   num_key_value_heads=self.head_num,
                                   torch_dtype=self.torch_dtype)

    def test_attention_with_lora(self):
        hidden_states, qkv_weight, out_weight = self._create_attention_inputs()
        mapping = Mapping(world_size=1, tp_size=1, rank=0)
        kv_cache_config = KvCacheConfig(max_tokens=max_seq_len)
        head_dim = llama_config.hidden_size // llama_config.num_attention_heads
        kv_cache_manager = KVCacheManager(
            kv_cache_config=kv_cache_config,
            kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
            CacheType.SELF,
            num_layers=llama_config.num_hidden_layers,
            num_kv_heads=llama_config.num_key_value_heads,
            head_dim=head_dim,
            tokens_per_block=128,
            max_seq_len=max_seq_len,
            max_batch_size=self.batch_size,
            mapping=mapping,
            dtype=tensorrt_llm.bindings.DataType.HALF)

        lora_params = self._create_lora_params()
        model_config = ModelConfig(pretrained_config=llama_config,
                                   attn_backend="VANILLA")
        attention_module = LlamaAttention(model_config, layer_idx=0).to(
            self.device).to(self.torch_dtype)

        attention_module, model_config = self._setup_attention_module(
            qkv_weight, out_weight)
        attention_module.qkv_proj.weight.data = torch.from_numpy(
            np.ascontiguousarray(qkv_weight.cpu().numpy().transpose(
                [1, 0]))).to(self.device)
        attention_module.o_proj.weight.data = torch.from_numpy(
            np.ascontiguousarray(out_weight.cpu().numpy().transpose(
                [1, 0]))).to(self.device)

        attn_metadata = self._create_attention_metadata(model_config)
        builder, net = self._setup_trt_network(hidden_states, lora_params,
                                               attention_module)
        trt_output = self._run_trt_inference(builder, net, hidden_states,
                                             lora_params)
        request_ids = [0]

        lora_params_pytorchflow = {
        kv_cache_manager.add_dummy_requests(request_ids=request_ids,
                                            token_nums=[self.seq_len])
        sequence_lengths = [self.seq_len]
        past_seen_tokens = [0]
        metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
        attn_metadata = metadata_cls(
            seq_lens=torch.tensor(sequence_lengths, dtype=torch.int32),
            num_contexts=len(sequence_lengths),
            kv_cache_params=KVCacheParams(
                use_cache=True,
                num_cached_tokens_per_seq=past_seen_tokens,
            ),
            kv_cache_manager=kv_cache_manager,
            request_ids=request_ids,
            prompt_lens=sequence_lengths,
            max_num_requests=self.batch_size,
            max_num_tokens=self.batch_size * self.seq_len,
        )

        lora_params_pytorch_flow = {
            'num_seqs': self.batch_size,
            'host_request_types': torch.zeros(self.batch_size,
                                              dtype=torch.int32),
            'prompt_lens_cpu': torch.tensor([self.seq_len] * self.batch_size),
            'host_request_types': host_request_types,
            'prompt_lens_cpu': host_context_lengths,
            'remove_input_padding': True,
            0: {  # layer_idx
                LoraModuleType.ATTENTION_Q: {  # Module type
                    'adapter_size':
                    torch.tensor([8]),
                    lora_params['lora_ranks'],
                    'weight_pointers':
                    torch.tensor([[
                        lora_params['lora_weight_outs'][0].data_ptr(),
                        lora_params['lora_weight_ins'][0].data_ptr()
                    ]]),
                    lora_params['lora_weights_pointers'],
                    'is_dora':
                    False,
                    'weight_tensors': [
                        lora_params['lora_weight_outs'][0],
                        lora_params['lora_weight_ins'][0]
                    ]
                },
                LoraModuleType.ATTENTION_K: {
                    'adapter_size':
                    torch.tensor([8]),
                    'weight_pointers':
                    torch.tensor([[
                        lora_params['lora_weight_outs'][0].data_ptr(),
                        lora_params['lora_weight_ins'][0].data_ptr()
                    ]]),
                    lora_params['lora_ranks'],
                    'weight_pointers': lora_params['lora_weights_pointers'],
                    'is_dora':
                    False,
                    'weight_tensors': [
                        lora_params['lora_weight_outs'][0],
                        lora_params['lora_weight_ins'][0]
                    ]
                },
                LoraModuleType.ATTENTION_V: {
                    'adapter_size':
                    torch.tensor([8]),
                    lora_params['lora_ranks'],
                    'weight_pointers':
                    torch.tensor([[
                        lora_params['lora_weight_outs'][0].data_ptr(),
                        lora_params['lora_weight_ins'][0].data_ptr()
                    ]]),
                    lora_params['lora_weights_pointers'],
                    'is_dora':
                    False,
                    'weight_tensors': [
                        lora_params['lora_weight_outs'][0],
                        lora_params['lora_weight_ins'][0]
                    ]
                },
                LoraModuleType.ATTENTION_DENSE: {
                    'adapter_size':
                    torch.tensor([8]),
                    lora_params['lora_ranks'],
                    'weight_pointers':
                    torch.tensor([[
                        lora_params['lora_weight_outs'][0].data_ptr(),
                        lora_params['lora_weight_ins'][0].data_ptr()
                    ]]),
                    lora_params['lora_weights_pointers'],
                    'is_dora':
                    False,
                    'weight_tensors': [
                        lora_params['lora_weight_outs'][0],
                        lora_params['lora_weight_ins'][0]
                    ]
                }
            }
        }

        with torch.inference_mode():
            attn_metadata.prepare()
            hidden_states_pytorchflow = hidden_states.squeeze(0)
            pytorchflow_output = attention_module(

            pytorch_flow_output = attention_module(
                position_ids=None,
                hidden_states=hidden_states_pytorchflow,
                hidden_states=hidden_states,
                attn_metadata=attn_metadata,
                attention_mask=PredefinedAttentionMask.CAUSAL,
                lora_params=lora_params_pytorchflow)
                lora_params=lora_params_pytorch_flow)

        torch.testing.assert_close(pytorchflow_output,
        trt_output = outputs['output']

        torch.testing.assert_close(pytorch_flow_output,
                                   trt_output,
                                   atol=2e-3,
                                   rtol=0)
@@ -0,0 +1,180 @@
import os
import sys
import unittest

import torch

import tensorrt_llm
from tensorrt_llm import Tensor

sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
from utils.util import create_session, run_session


class TestLoraPluginVsLayer(unittest.TestCase):

    def setUp(self):
        tensorrt_llm.logger.set_level('info')
        torch.random.manual_seed(0)
        self.dtype = 'float16'
        self.torch_dtype = torch.float16
        self.device = 'cuda'
        self.batch_size = 4
        self.seq_len = 8
        self.hidden_size = 1024
        self.lora_rank = 8
        self.is_remove_input_padding = False
        self.weight_index = 0
        self.transA = False
        self.transB = True

    def _create_input_tensors(self, batch_size, seq_len, hidden_size,
                              lora_ranks_list):
        input_tensor = torch.randn(batch_size,
                                   seq_len,
                                   hidden_size,
                                   dtype=self.torch_dtype,
                                   device=self.device) * 0.1

        lora_weight_ins = [
            torch.randn(hidden_size, lora_rank, device=self.device).to(
                self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list
        ]
        lora_weight_outs = [
            torch.randn(lora_rank, hidden_size, device=self.device).to(
                self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list
        ]

        lora_weight_ins = [tmp.contiguous() for tmp in lora_weight_ins]
        lora_weight_outs = [
            tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs
        ]

        # Create LoRA weight pointers
        lora_weights_pointers = []
        for in_ptr, out_ptr in zip(lora_weight_ins, lora_weight_outs):
            lora_weights_pointers.append(in_ptr.data_ptr())
            lora_weights_pointers.append(out_ptr.data_ptr())
            # null dora scale
            lora_weights_pointers.append(0)

        lora_weights_pointers = torch.LongTensor(lora_weights_pointers).to(
            torch.int64).reshape([batch_size, 3])

        # Create other tensors
        host_context_lengths = torch.Tensor(
            [seq_len for _ in range(batch_size)]).to(torch.int32)
        lora_ranks = torch.Tensor(lora_ranks_list).to(torch.int32)
        host_request_types = torch.zeros_like(host_context_lengths,
                                              device='cpu').int()

        return {
            'input_tensor': input_tensor,
            'lora_weight_ins': lora_weight_ins,
            'lora_weight_outs': lora_weight_outs,
            'lora_weights_pointers': lora_weights_pointers,
            'host_context_lengths': host_context_lengths,
            'lora_ranks': lora_ranks,
            'host_request_types': host_request_types,
            'batch_size': batch_size,
            'seq_len': seq_len,
            'hidden_size': hidden_size,
            'max_lora_rank': max(max(lora_ranks_list), 8)
        }

    def _create_lora_plugin_session(self, tensors):
        # Construct TensorRT network
        builder = tensorrt_llm.Builder()
        network = builder.create_network()
        network.plugin_config.set_lora_plugin(self.dtype)
        network.plugin_config.remove_input_padding = self.is_remove_input_padding

        with tensorrt_llm.net_guard(network):
            input_tensor = Tensor(name='input_tensor',
                                  shape=[
                                      tensors['batch_size'], tensors['seq_len'],
                                      tensors['hidden_size']
                                  ],
                                  dtype=tensorrt_llm.str_dtype_to_trt(
                                      self.dtype))
            host_request_types_tensor = Tensor(
                name='host_request_types',
                shape=[tensors['batch_size']],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            host_context_lengths_tensor = Tensor(
                name='host_context_lengths',
                shape=[tensors['batch_size']],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            lora_ranks_tensor = Tensor(
                name='lora_ranks',
                shape=[tensors['batch_size']],
                dtype=tensorrt_llm.str_dtype_to_trt('int32'))
            lora_weights_pointers_tensor = Tensor(
                name='lora_weights_pointers',
                shape=[tensors['batch_size'], 3],
                dtype=tensorrt_llm.str_dtype_to_trt('int64'))

            output = tensorrt_llm.functional.lora_plugin(
                input_tensor,
                tensors['hidden_size'],
                [tensors['hidden_size']],
                host_request_types_tensor,
                self.transA,  # transA
                self.transB,  # transB
                host_context_lengths_tensor,
                tensors['max_lora_rank'],
                [lora_ranks_tensor],
                [lora_weights_pointers_tensor],
                weight_index=self.weight_index,
            )
            output.mark_output('output')

        return create_session(builder, network, precision=self.dtype)

    def _run_lora_grouped_gemm(self, tensors):
        """Run the lora_grouped_gemm operation directly"""
        # Prepare parameters for lora_grouped_gemm
        x = tensors['input_tensor']
        host_request_types = tensors[
            'host_request_types'][:tensors['batch_size']]
        lora_ranks = tensors['lora_ranks']
        lora_weight_pointers = tensors['lora_weights_pointers']
        prompt_lens_cpu = tensors[
            'host_context_lengths'][:tensors['batch_size']]
        output_hidden_sizes = [tensors['hidden_size']]
        transA = self.transA
        transB = self.transB
        max_rank = max([r.item() for r in lora_ranks])
        weight_index = self.weight_index
        is_remove_input_padding = self.is_remove_input_padding

        lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
            x, host_request_types, [lora_ranks], [lora_weight_pointers],
            prompt_lens_cpu, output_hidden_sizes, transA, transB, max_rank,
            weight_index, is_remove_input_padding)

        return lora_outputs[0]

    def test_lora_plugin_vs_lora_op(self):
        lora_ranks_list = [self.lora_rank] * self.batch_size

        tensors = self._create_input_tensors(self.batch_size, self.seq_len,
                                             self.hidden_size, lora_ranks_list)

        session = self._create_lora_plugin_session(tensors)
        inputs = {
            'input_tensor': tensors['input_tensor'],
            'host_request_types': tensors['host_request_types'],
            'host_context_lengths': tensors['host_context_lengths'],
            'lora_ranks': tensors['lora_ranks'],
            'lora_weights_pointers': tensors['lora_weights_pointers'],
        }
        outputs = run_session(session, inputs)
        torch.cuda.synchronize()

        lora_outputs = self._run_lora_grouped_gemm(tensors)

        torch.testing.assert_close(outputs['output'],
                                   lora_outputs,
                                   atol=5e-3,
                                   rtol=0.3)