added loraOp into lora layer + test for mlp and comparison to lora plugin (#3455)

LoraOp integration into torch modules

Signed-off-by: Ubuntu <dafrimi@nvidia.com>
danielafrimi 2025-04-17 07:48:27 +03:00 committed by GitHub
parent 239fe0ff26
commit 0f084d9566
5 changed files with 680 additions and 263 deletions

View File

@ -70,7 +70,8 @@ add_library(
userbuffersFinalizeOp.cpp
userbuffersTensor.cpp
weightOnlyQuantOp.cpp
mtpOp.cpp)
mtpOp.cpp
loraOp.cpp)
set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(th_common PRIVATE ${TORCH_LIBRARIES} th_utils
${Python3_LIBRARIES} ${SHARED_TARGET})

View File

@ -0,0 +1,192 @@
/*
* Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/cublasMMWrapper.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/opUtils.h"
#include "tensorrt_llm/kernels/lora/lora.h"
#include "tensorrt_llm/kernels/selectiveScan/selectiveScan.h"
#include "tensorrt_llm/thop/thUtils.h"
namespace th = torch;
namespace tk = tensorrt_llm::kernels;
using tensorrt_llm::common::fmtstr;
namespace torch_ext
{
enum class RequestType : int32_t
{
kCONTEXT = 0,
kGENERATION = 1
};
int64_t getNumTokens(th::Tensor const& input)
{
int ndim = input.sizes().size();
TLLM_CHECK_WITH_INFO(
3 == ndim || 2 == ndim, "hidden_state dimension should be either 2 [numTokens, hidden], or 3 [b, s, hidden]");
int64_t num_tokens = input.sizes()[0];
if (ndim == 3)
{
num_tokens *= input.sizes()[1];
}
return num_tokens;
}
std::vector<th::Tensor> lora_grouped_gemm(th::Tensor const& input, th::Tensor const& host_request_types,
std::vector<th::Tensor> const& lora_ranks, // numModules tensors, each tensors has single value
std::vector<th::Tensor> const& lora_weights_pointers, th::Tensor const& host_context_lengths,
std::vector<int64_t> const& output_hidden_sizes, bool transA, bool transB, int64_t const max_low_rank,
int64_t const& weight_index, bool isRemoveInputPadding)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto stream = at::cuda::getCurrentCUDAStream().stream();
auto const numReqs = lora_ranks[0].sizes()[0];
auto const out_shape = input.sizes();
int const numLoraModules = lora_ranks.size();
TLLM_CHECK_WITH_INFO(lora_ranks.size() == lora_weights_pointers.size(), "both should be numLoraModules");
std::vector<th::Tensor> output_torch;
for (int i = 0; i < numLoraModules; i++)
{
std::vector<int64_t> output_shape = {out_shape[0], output_hidden_sizes[i]};
if (!isRemoveInputPadding)
{
output_shape = {out_shape[0], out_shape[1], output_hidden_sizes[i]};
}
output_torch.push_back(torch::empty(output_shape, input.options()));
}
std::vector<void*> output;
for (auto tensor_it = output_torch.begin(); tensor_it != output_torch.end(); tensor_it++)
{
output.push_back(tensor_it->data_ptr());
}
int const seqLen = isRemoveInputPadding ? 0 : input.sizes()[1];
int32_t const* reqTypes = static_cast<int32_t const*>(host_request_types.data_ptr());
int32_t const* hostContextLengths
= isRemoveInputPadding ? static_cast<int32_t const*>(host_context_lengths.data_ptr()) : nullptr;
int64_t numTokens = getNumTokens(input);
std::vector<void const*> expandLoraWeightPtrs{};
std::vector<int32_t> expandLoraRanks{};
expandLoraWeightPtrs.reserve(numLoraModules * numTokens * 2);
expandLoraRanks.reserve(numLoraModules * numTokens);
for (int loraModuleIdx = 0; loraModuleIdx < numLoraModules; loraModuleIdx++)
{
auto const loraRankModule = static_cast<int32_t const*>(lora_ranks[loraModuleIdx].data_ptr());
auto const loraWeightModulePtrs = static_cast<int64_t const*>(lora_weights_pointers[loraModuleIdx].data_ptr());
int idx = 0;
for (int reqId = 0; reqId < numReqs; reqId++)
{
// loraWeightModulePtrs has 3 pointers for each module: A,B, and an optional DoRA magnitude
// the current LoRA plugin does not apply DoRA scaling, so the magnitude is ignored
RequestType const reqType = static_cast<RequestType const>(reqTypes[reqId]);
if (reqType == RequestType::kGENERATION)
{
expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3]));
expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3 + 1]));
expandLoraRanks.push_back(loraRankModule[reqId]);
idx += 1;
}
else
{
int contextLen = (isRemoveInputPadding ? hostContextLengths[reqId] : seqLen);
for (int contextId = 0; contextId < contextLen; contextId++)
{
expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3]));
expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3 + 1]));
expandLoraRanks.push_back(loraRankModule[reqId]);
idx += 1;
}
}
}
// In the 1st generation phase of cross-attention QKV LoRA, cross QKV is skipped by passing an empty
// encoder_output (a tensor with a 0 dim), so getNumTokens() returns 0 here. Skip the check in that case.
if (numTokens > 0)
{
TLLM_CHECK_WITH_INFO(idx == numTokens,
fmtstr("LoraParams and input dims don't match, lora tokens %d input tokens %ld", idx, numTokens));
}
}
auto cublasHandle = getCublasHandle();
auto cublasLtHandle = getCublasLtHandle();
auto cublasWrapper
= std::make_shared<tensorrt_llm::common::CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
int const inHiddenSize = input.sizes()[input.sizes().size() - 1];
std::vector<int> outHiddenSizes(output_hidden_sizes.size());
for (int i = 0; i < numLoraModules; i++)
{
outHiddenSizes[i] = output_hidden_sizes[i];
}
nvinfer1::DataType loraRuntimeDataType;
switch (input.scalar_type())
{
case torch::kFloat16: loraRuntimeDataType = nvinfer1::DataType::kHALF; break;
case torch::kBFloat16: loraRuntimeDataType = nvinfer1::DataType::kBF16; break;
default: throw std::invalid_argument("Invalid dtype, only supports float16, bfloat16");
}
auto mLoraImpl = std::make_shared<tensorrt_llm::kernels::LoraImpl>(
inHiddenSize, outHiddenSizes, transA, transB, numLoraModules, loraRuntimeDataType, max_low_rank, cublasWrapper);
// TODO (dafrimi): use Profiler to find the best tactic as used in lora_plugin
mLoraImpl->setBestTactic(std::nullopt);
auto const workspace_size = mLoraImpl->getWorkspaceSize(numTokens, numReqs, loraRuntimeDataType);
auto workspace = torch::empty(std::vector<int64_t>{static_cast<int64_t>(workspace_size)}, input.options());
mLoraImpl->run(numTokens, numReqs, input.data_ptr(), expandLoraRanks.data(), expandLoraWeightPtrs.data(),
weight_index, output.data(), workspace.data_ptr(), stream);
sync_check_cuda_error(stream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return output_torch;
}
} // namespace torch_ext
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"lora_grouped_gemm(Tensor input, "
"Tensor host_request_types, "
"Tensor [] lora_ranks, "
"Tensor [] lora_weights_pointers, "
"Tensor host_context_lengths, "
"int [] output_hidden_sizes, "
"bool transA, "
"bool transB, "
"int max_low_rank, "
"int weight_index, "
"bool isRemoveInputPadding) -> Tensor[]");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("lora_grouped_gemm", &torch_ext::lora_grouped_gemm);
}
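With the fragment above registered, the op is callable from Python as torch.ops.trtllm.lora_grouped_gemm. Below is a minimal sketch of a padded (non-ragged) call that mirrors the pointer layout used by the unit test added in this commit; the shapes, the rank, and the reuse of one adapter for every request are illustrative assumptions, not part of the committed code.

import torch
import tensorrt_llm  # assumed to load the trtllm torch extension that contains loraOp

batch, seq, hidden, rank = 2, 8, 64, 8
x = torch.randn(batch, seq, hidden, dtype=torch.float16, device='cuda') * 0.1

# One LoRA module; weight layout follows the new unit test: the "in" weight is
# [hidden, rank] and the "out" weight is stored transposed as [hidden, rank].
weight_in = (torch.randn(hidden, rank, device='cuda', dtype=torch.float16) * 0.1).contiguous()
weight_out = (torch.randn(rank, hidden, device='cuda', dtype=torch.float16) * 0.1).transpose(1, 0).contiguous()

# Three pointers per request: in, out, and a null DoRA magnitude (0).
pointers = torch.tensor([[weight_in.data_ptr(), weight_out.data_ptr(), 0]] * batch, dtype=torch.int64)
ranks = torch.full((batch,), rank, dtype=torch.int32)
host_request_types = torch.zeros(batch, dtype=torch.int32)        # 0 = context request
host_context_lengths = torch.full((batch,), seq, dtype=torch.int32)

outputs = torch.ops.trtllm.lora_grouped_gemm(
    x, host_request_types,
    [ranks], [pointers],          # one entry per LoRA module
    host_context_lengths,
    [hidden],                     # output hidden size per module
    False, True,                  # transA, transB
    rank,                         # max_low_rank
    0,                            # weight_index
    False)                        # isRemoveInputPadding
# outputs[0] has shape [batch, seq, hidden]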

View File

@ -95,90 +95,61 @@ class LoraLayer(torch.nn.Module):
def forward(self, x, lora_params: Dict,
layer_idx: int) -> Optional[torch.Tensor]:
if bool(lora_params):
lora_ranks = []
lora_weight_pointers = []
lora_weight_tensors = [
] # TODO (dafrimi) needs to delete this when we use loraOps which uses ptr
active_lora_module_ids = []
for module_idx in self.lora_module_types:
module_idx = int(module_idx)
if module_idx in lora_params[layer_idx]:
active_lora_module_ids.append(module_idx)
is_dora = lora_params[layer_idx][module_idx][
'is_dora'] # todo (dafrimi) use it when calling loraOP
# TODO (dafrimi): needs to pass this is_dora arg
lora_params[layer_idx][module_idx]['is_dora']
lora_ranks.append(
lora_params[layer_idx][module_idx]['adapter_size'])
lora_weight_pointers.append(
lora_params[layer_idx][module_idx]['weight_pointers'])
lora_weight_tensors.append(
lora_params[layer_idx][module_idx]['weight_tensors']
) # TODO (dafrimi) needs to delete this when we use loraOps which uses ptr
lora_params['num_seqs']
num_seqs = lora_params['num_seqs']
if len(active_lora_module_ids) == 0:
return None
else:
# If there's only one module, use the existing implementation
if len(active_lora_module_ids) == 1:
lora_weight_tensors = lora_weight_tensors[0]
lora_output = (
x @ lora_weight_tensors[1].T) @ lora_weight_tensors[0].T
lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
x,
lora_params['host_request_types'][:num_seqs],
lora_ranks,
lora_weight_pointers,
lora_params['prompt_lens_cpu'][:num_seqs],
self.output_hidden_sizes,
False, # transA
True, # transB
max([r.max() for r in lora_ranks]),
0,
lora_params["remove_input_padding"],
)
if isinstance(lora_outputs, torch.Tensor):
return lora_outputs
else:
# For multiple LoRA modules, some might not be executed in grouped gemm.
# For those modules not executed, we create zero tensors with matching dimensions.
# Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order.
lora_output = []
for module_idx in self.lora_module_types:
if int(module_idx) in active_lora_module_ids:
lora_output.append(lora_outputs.pop(0))
else:
lora_output.append(
torch.zeros(list(x.shape[:-1]) + [
self.output_hidden_sizes[
self.lora_module_types.index(
module_idx)]
],
dtype=x.dtype,
device=x.device))
lora_output = torch.cat(lora_output, dim=-1)
return lora_output
# For multiple modules, compute and concatenate outputs
lora_outputs = []
for module_idx in self.lora_module_types:
module_idx = int(module_idx)
if module_idx in active_lora_module_ids:
i = active_lora_module_ids.index(module_idx)
weight_tensors = lora_weight_tensors[i]
A, B = weight_tensors[0], weight_tensors[1]
lora_output = (x @ B.T) @ A.T
lora_outputs.append(lora_output)
# Concatenate outputs from all modules
lora_output = torch.cat(lora_outputs, dim=-1)
return lora_output
# TODO(dafrimi): use torch implementation. For now, this is just a placeholder until we do the binding to the LoRA ops in C++
# lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
# x,
# lora_params['host_request_types'][:num_seqs],
# lora_ranks,
# lora_weight_pointers,
# lora_params['prompt_lens_cpu'][:num_seqs],
# self.output_hidden_sizes,
# False, # transA
# True, # transB
# max([r.max() for r in lora_ranks]),
# 0,
# True,
# )
# if isinstance(lora_outputs, torch.Tensor):
# return lora_outputs
# else:
# # For multiple LoRA modules, some might not be executed in grouped gemm.
# # For those modules not executed, we create zero tensors with matching dimensions.
# # Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order.
# lora_output = []
# for module_idx in self.lora_module_types:
# if int(module_idx) in active_lora_module_ids:
# lora_output.append(lora_outputs.pop(0))
# else:
# lora_output.append(
# torch.zeros(list(x.shape[:-1]) + [
# self.output_hidden_sizes[
# self.lora_module_types.index(
# module_idx)]
# ],
# dtype=x.dtype,
# device=x.device))
# lora_output = torch.cat(lora_output, dim=-1)
# return lora_output
else:
return None
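For reference, the torch-only computation this op replaces in the layer is the two-step LoRA product per module, with zero tensors concatenated in for modules that have no adapter. The sketch below captures that reference behavior, assuming A is the [rank, hidden] down-projection and B is the [out_hidden, rank] up-projection; the weight layout is an assumption made for illustration only.

import torch

def lora_layer_reference(x, module_adapters, output_hidden_sizes):
    # module_adapters: one entry per LoRA module type, either (A, B) or None
    # when that module has no adapter for this layer.
    outs = []
    for adapters, out_size in zip(module_adapters, output_hidden_sizes):
        if adapters is None:
            # Inactive module: zero output of that module's hidden size, so the
            # concatenated width stays sum(output_hidden_sizes).
            outs.append(x.new_zeros(*x.shape[:-1], out_size))
        else:
            A, B = adapters
            outs.append((x @ A.T) @ B.T)  # down-project to rank, then up-project
    return torch.cat(outs, dim=-1)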

View File

@ -12,11 +12,13 @@ from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_llama import LlamaAttention
# LoRA Imports
from tensorrt_llm._torch.peft.lora.layer import LoraModuleType
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.layers import Attention
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.layers import AttentionParams, KeyValueCacheParams
from tensorrt_llm.layers.lora import Lora, LoraParams
from tensorrt_llm.mapping import Mapping
@ -27,94 +29,34 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase):
def setUpClass(cls):
cls.batch_size = 1
cls.seq_len = 16
cls.hidden_size = 64
cls.head_num = 1
cls.num_hidden_layers = 1
cls.head_size = 64
cls.hidden_size = cls.head_num * cls.head_size
cls.dtype = 'float16'
cls.torch_dtype = str_dtype_to_torch(cls.dtype)
cls.device = torch.device('cuda')
cls.pos_emb_type = PositionEmbeddingType.learned_absolute
cls.causal_mask = True
# KV cache parameters
cls.num_blocks = 4
cls.tokens_per_block = 32
def _create_lora_params(self, ):
lora_ranks_list = [8 for _ in range(self.batch_size)]
cls.llama_config = LlamaConfig(hidden_size=cls.hidden_size,
num_attention_heads=cls.head_num,
num_hidden_layers=cls.num_hidden_layers,
intermediate_size=256,
max_position_embeddings=512,
rms_norm_eps=1e-5,
vocab_size=32000,
num_key_value_heads=cls.head_num,
torch_dtype=cls.torch_dtype)
# Create KV cache manager
head_dim = cls.llama_config.hidden_size // cls.llama_config.num_attention_heads
mapping = Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = KvCacheConfig(max_tokens=cls.num_blocks *
cls.tokens_per_block)
cls.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
CacheType.SELF,
num_layers=cls.llama_config.num_hidden_layers,
num_kv_heads=cls.llama_config.num_key_value_heads,
head_dim=head_dim,
tokens_per_block=cls.tokens_per_block,
max_seq_len=cls.num_blocks * cls.tokens_per_block,
max_batch_size=cls.batch_size,
mapping=mapping,
dtype=tensorrt_llm.bindings.DataType.HALF)
@classmethod
def tearDownClass(cls):
cls.kv_cache_manager.shutdown()
def _create_attention_inputs(self):
hidden_states = torch.empty(
size=[self.batch_size, self.seq_len, self.hidden_size],
dtype=self.torch_dtype,
device='cuda')
hidden_states.normal_(0.0, 0.02)
# Create weights
q_weight = torch.empty(size=[self.hidden_size, self.hidden_size],
dtype=self.torch_dtype)
torch.nn.init.xavier_uniform_(q_weight)
# Set K and V and O weights to identity matrix
eye_weight = torch.eye(self.hidden_size, dtype=self.torch_dtype)
qkv_weight = torch.cat([q_weight, eye_weight, eye_weight], dim=-1)
out_weight = eye_weight
return hidden_states, qkv_weight, out_weight
def _create_lora_params(self):
lora_ranks_list = [8]
host_context_lengths = torch.Tensor(
[self.seq_len for _ in range(self.batch_size)]).to(torch.int32)
lora_ranks = torch.Tensor(lora_ranks_list).to(torch.int32)
host_request_types = torch.zeros_like(host_context_lengths,
device='cpu').int()
lora_weight_ins = [
torch.randn(self.hidden_size, lora_rank, device=self.device).to(
torch.randn(self.hidden_size, lora_rank, device="cuda").to(
self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list
]
lora_weight_outs = [
torch.randn(lora_rank, self.hidden_size, device=self.device).to(
torch.randn(lora_rank, self.hidden_size, device="cuda").to(
self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list
]
lora_weight_ins = [
tmp.transpose(1, 0).contiguous() for tmp in lora_weight_ins
]
lora_weight_ins = [tmp.contiguous() for tmp in lora_weight_ins]
lora_weight_outs = [
tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs
]
# Create weight pointers for TensorRT
lora_weights_pointers = []
for in_ptr, out_ptr in zip(lora_weight_ins, lora_weight_outs):
lora_weights_pointers.append(in_ptr.data_ptr())
@ -125,88 +67,164 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase):
return {
'lora_ranks': lora_ranks,
'host_context_lengths': host_context_lengths,
'host_request_types': host_request_types,
'lora_weights_pointers': lora_weights_pointers,
'lora_weight_ins': lora_weight_ins,
'lora_weight_outs': lora_weight_outs
}
def _setup_attention_module(self, qkv_weight, out_weight):
"""Set up the attention module with weights."""
model_config = ModelConfig(pretrained_config=self.llama_config,
attn_backend="VANILLA")
layer_idx = 0
attention_module = LlamaAttention(model_config, layer_idx=layer_idx).to(
self.device).to(self.torch_dtype)
def test_lora_attention(self):
# Set weights
attention_module.qkv_proj.weight.data = torch.from_numpy(
np.ascontiguousarray(qkv_weight.cpu().numpy().transpose(
[1, 0]))).to(self.device)
attention_module.o_proj.weight.data = torch.from_numpy(
np.ascontiguousarray(out_weight.cpu().numpy().transpose(
[1, 0]))).to(self.device)
mean = 0.0
std_dev = 0.02 if self.dtype == "float32" else 0.005
return attention_module, model_config
hidden_states = torch.concat([
torch.empty(size=[self.seq_len, self.hidden_size],
dtype=self.torch_dtype,
device=self.device).normal_(mean, std_dev)
for _ in range(self.batch_size)
])
def _create_attention_metadata(self, model_config):
sequence_lengths = [self.seq_len]
past_seen_tokens = [0]
request_ids = [0]
token_nums = [self.seq_len]
prompt_lens = token_nums
context_lengths = torch.full([self.batch_size],
self.seq_len,
dtype=torch.int32,
device=self.device)
self.kv_cache_manager.add_dummy_requests(request_ids, token_nums)
# Plugin specific setup - only generate 1 step
max_seq_len = self.seq_len + 1
metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
return metadata_cls(
seq_lens=torch.tensor(sequence_lengths, dtype=torch.int32),
num_contexts=len(sequence_lengths),
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=past_seen_tokens,
),
kv_cache_manager=self.kv_cache_manager,
request_ids=request_ids,
prompt_lens=prompt_lens,
max_num_requests=self.batch_size,
max_num_tokens=self.batch_size * self.seq_len,
)
# zero means "valid" token, one means invalid.
host_past_key_value_lengths = torch.tensor([0] * self.batch_size,
dtype=torch.int32)
# the max kv cache length for each layer. single tensor since we only have 1 layer here.
host_max_attention_window_sizes = torch.tensor([max_seq_len],
dtype=torch.int32)
host_sink_token_length = torch.tensor([0], dtype=torch.int32)
sequence_length = torch.full([self.batch_size],
self.seq_len,
dtype=torch.int32,
device=self.device)
# even in the context phase, kv cache tensors can not be empty tensors for the plugin,
# otherwise there will be a cublas execution error. The actual shape info
# is passed to the plugin by the `sequence_length` tensor
kv_shape = (self.batch_size, 2, self.head_num, max_seq_len,
self.head_size)
past_key_value = torch.randn(kv_shape,
dtype=self.torch_dtype,
device=self.device)
cache_indirection = torch.full((
self.batch_size,
1,
max_seq_len,
),
0,
dtype=torch.int32,
device=self.device)
host_request_types = torch.tensor([0] * self.batch_size,
dtype=torch.int32,
device='cpu')
perf_knob_tensor_size = 16
host_runtime_perf_knobs_tensor = torch.tensor([-1] *
perf_knob_tensor_size,
dtype=torch.int64,
device='cpu')
host_context_progress = torch.tensor([0],
dtype=torch.int64,
device='cpu')
host_context_lengths = torch.Tensor(
[self.seq_len for _ in range(self.batch_size)]).to(torch.int32)
q_weight = torch.empty(size=[self.hidden_size, self.hidden_size],
dtype=self.torch_dtype)
torch.nn.init.xavier_uniform_(q_weight)
# The initialization here is chosen to minimize computation after the
# QKV BMMs in order to reduce the amount of differences from FP accumulation.
# We set K and V weights to the identity matrix so that the input is copied
# without doing any accumulation. Additionally, we set the output projection
# to the identity for the same reason.
# The main purpose of these tests is to check the QK^T BMM + Softmax + SV BMM for LoRA.
eye_weight = torch.eye(self.hidden_size, dtype=self.torch_dtype)
qkv_weight = torch.cat([q_weight, eye_weight, eye_weight], dim=-1)
out_weight = eye_weight
lora_params = self._create_lora_params()
def _setup_trt_network(self, hidden_states, lora_params, attention_module):
builder = tensorrt_llm.Builder()
net = builder.create_network()
net.plugin_config.to_legacy_setting()
net.plugin_config.lora_plugin = self.dtype
net.plugin_config.gpt_attention_plugin = self.dtype # for ragged input we use this plugin with remove_input_padding
net.plugin_config.remove_input_padding = True
net.plugin_config.lora_plugin = "float16"
with tensorrt_llm.net_guard(net):
# Create input tensor
trt_hidden_states = Tensor(name='hidden_states',
shape=hidden_states.shape,
dtype=tensorrt_llm.str_dtype_to_trt(
self.dtype))
# Create LoRA tensors
host_request_types_tensor = Tensor(
name='host_request_types',
shape=[lora_params['host_request_types'].shape[0]],
context_lengths_tensor = Tensor(
name='context_lengths',
shape=context_lengths.shape,
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
host_context_lengths_tensor = Tensor(
name='host_context_lengths',
shape=[lora_params['host_context_lengths'].shape[0]],
shape=[host_context_lengths.shape[0]],
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
lora_ranks_tensor = Tensor(
name='lora_ranks',
shape=[lora_params['lora_ranks'].shape[0]],
shape=(lora_params['lora_ranks'].shape[0], ),
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
lora_weights_pointers_tensor = Tensor(
name='lora_weights_pointers',
shape=lora_params['lora_weights_pointers'].shape,
dtype=tensorrt_llm.str_dtype_to_trt('int64'))
# Create LoRA parameters
lora_params = LoraParams(
host_request_types_tensor = Tensor(
name='host_request_types',
shape=host_request_types.shape,
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
past_key_value_tensor = Tensor(name='past_key_value',
shape=tuple(past_key_value.shape),
dtype=tensorrt_llm.str_dtype_to_trt(
self.dtype))
sequence_length_tensor = Tensor(
name='sequence_length',
shape=tuple(sequence_length.shape),
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
host_past_key_value_lengths_tensor = Tensor(
name='host_past_key_value_lengths',
shape=tuple(host_past_key_value_lengths.shape),
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
host_max_attention_window_sizes_tensor = Tensor(
name='host_max_attention_window_sizes',
shape=tuple(host_max_attention_window_sizes.shape),
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
host_sink_token_length_tensor = Tensor(
name='host_sink_token_length',
shape=tuple(host_sink_token_length.shape),
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
cache_indirection_tensor = Tensor(
name='cache_indirection',
shape=tuple(cache_indirection.shape),
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
host_runtime_perf_knobs = Tensor(
name='host_runtime_perf_knobs',
shape=[16],
dtype=tensorrt_llm.str_dtype_to_trt('int64'))
host_context_progress_tensor = Tensor(
name='host_context_progress',
shape=[1],
dtype=tensorrt_llm.str_dtype_to_trt('int64'))
lora_layer_params = LoraParams(
lora_ranks=[{
"attn_q_lora_ranks": lora_ranks_tensor,
"attn_k_lora_ranks": lora_ranks_tensor,
@ -224,15 +242,17 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase):
lora_weights_pointers_tensor,
}],
host_context_lengths=host_context_lengths_tensor,
host_request_types=host_request_types_tensor)
host_request_types=host_request_types_tensor,
)
attn_layer = Attention(
attn_layer = tensorrt_llm.layers.Attention(
local_layer_idx=0,
hidden_size=hidden_states.shape[-1],
num_attention_heads=1,
max_position_embeddings=hidden_states.shape[1],
hidden_size=self.hidden_size,
num_attention_heads=self.head_num,
max_position_embeddings=self.seq_len,
attention_mask_type=tensorrt_llm.layers.AttentionMaskType.
causal,
position_embedding_type=self.pos_emb_type,
bias=False)
attn_layer.qkv_lora = Lora(
@ -254,29 +274,62 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase):
max_low_rank=8,
)
# Set attention layer weights
attn_layer.qkv.weight.value = attention_module.qkv_proj.weight.data
attn_layer.dense.weight.value = attention_module.o_proj.weight.data
attn_layer.qkv.weight.value = np.ascontiguousarray(
qkv_weight.cpu().numpy().transpose([1, 0]))
attn_layer.dense.weight.value = np.ascontiguousarray(
out_weight.cpu().numpy().transpose([1, 0]))
output = attn_layer(hidden_states=trt_hidden_states,
lora_layer_params=lora_params)
output, present_key_value = attn_layer(
trt_hidden_states,
use_cache=True,
lora_layer_params=
lora_layer_params, # Always use cache for plugin path in this test
kv_cache_params=KeyValueCacheParams(
past_key_value=[past_key_value_tensor],
host_past_key_value_lengths=
host_past_key_value_lengths_tensor,
host_max_attention_window_sizes=
host_max_attention_window_sizes_tensor,
host_sink_token_length=host_sink_token_length_tensor,
cache_indirection=cache_indirection_tensor),
attention_params=AttentionParams(
sequence_length=sequence_length_tensor,
context_lengths=context_lengths_tensor,
host_request_types=host_request_types_tensor,
max_context_length=self.seq_len,
host_runtime_perf_knobs=host_runtime_perf_knobs,
host_context_progress=host_context_progress_tensor,
host_context_lengths=host_context_lengths_tensor,
))
assert isinstance(output, Tensor)
output.mark_output('output',
tensorrt_llm.str_dtype_to_trt(self.dtype))
present_key_value.mark_output(
'present_key_value', tensorrt_llm.str_dtype_to_trt(self.dtype))
return builder, net
def _run_trt_inference(self, builder, net, hidden_states, lora_params):
builder_config = builder.create_builder_config(name='attention',
builder_config = builder.create_builder_config(name='attention_plugin',
precision=self.dtype)
engine_buffer = builder.build_engine(net, builder_config)
session = tensorrt_llm.runtime.Session.from_serialized_engine(
engine_buffer)
stream = torch.cuda.current_stream().cuda_stream
inputs = {
'hidden_states': hidden_states,
'host_request_types': lora_params['host_request_types'],
'host_context_lengths': lora_params['host_context_lengths'],
'past_key_value': past_key_value,
'sequence_length': sequence_length,
'host_past_key_value_lengths': host_past_key_value_lengths,
'host_max_attention_window_sizes': host_max_attention_window_sizes,
'host_sink_token_length': host_sink_token_length,
'context_lengths': context_lengths,
'host_request_types': host_request_types,
'cache_indirection': cache_indirection,
'host_runtime_perf_knobs': host_runtime_perf_knobs_tensor,
'host_context_progress': host_context_progress,
'host_context_lengths': host_context_lengths,
'lora_ranks': lora_params['lora_ranks'],
'lora_weights_pointers': lora_params['lora_weights_pointers'],
}
@ -286,108 +339,128 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase):
torch.empty(hidden_states.shape,
dtype=tensorrt_llm._utils.str_dtype_to_torch(
self.dtype),
device='cuda'),
device=self.device),
'present_key_value':
past_key_value,
}
session.run(inputs=inputs, outputs=outputs, stream=stream)
torch.cuda.synchronize()
return outputs['output'].squeeze(0)
# Pytorch flow
llama_config = LlamaConfig(hidden_size=self.hidden_size,
num_attention_heads=self.head_num,
num_hidden_layers=1,
intermediate_size=256,
max_position_embeddings=512,
rms_norm_eps=1e-5,
vocab_size=32000,
num_key_value_heads=self.head_num,
torch_dtype=self.torch_dtype)
def test_attention_with_lora(self):
hidden_states, qkv_weight, out_weight = self._create_attention_inputs()
mapping = Mapping(world_size=1, tp_size=1, rank=0)
kv_cache_config = KvCacheConfig(max_tokens=max_seq_len)
head_dim = llama_config.hidden_size // llama_config.num_attention_heads
kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config,
kv_cache_type=tensorrt_llm.bindings.internal.batch_manager.
CacheType.SELF,
num_layers=llama_config.num_hidden_layers,
num_kv_heads=llama_config.num_key_value_heads,
head_dim=head_dim,
tokens_per_block=128,
max_seq_len=max_seq_len,
max_batch_size=self.batch_size,
mapping=mapping,
dtype=tensorrt_llm.bindings.DataType.HALF)
lora_params = self._create_lora_params()
model_config = ModelConfig(pretrained_config=llama_config,
attn_backend="VANILLA")
attention_module = LlamaAttention(model_config, layer_idx=0).to(
self.device).to(self.torch_dtype)
attention_module, model_config = self._setup_attention_module(
qkv_weight, out_weight)
attention_module.qkv_proj.weight.data = torch.from_numpy(
np.ascontiguousarray(qkv_weight.cpu().numpy().transpose(
[1, 0]))).to(self.device)
attention_module.o_proj.weight.data = torch.from_numpy(
np.ascontiguousarray(out_weight.cpu().numpy().transpose(
[1, 0]))).to(self.device)
attn_metadata = self._create_attention_metadata(model_config)
builder, net = self._setup_trt_network(hidden_states, lora_params,
attention_module)
trt_output = self._run_trt_inference(builder, net, hidden_states,
lora_params)
request_ids = [0]
lora_params_pytorchflow = {
kv_cache_manager.add_dummy_requests(request_ids=request_ids,
token_nums=[self.seq_len])
sequence_lengths = [self.seq_len]
past_seen_tokens = [0]
metadata_cls = get_attention_backend(model_config.attn_backend).Metadata
attn_metadata = metadata_cls(
seq_lens=torch.tensor(sequence_lengths, dtype=torch.int32),
num_contexts=len(sequence_lengths),
kv_cache_params=KVCacheParams(
use_cache=True,
num_cached_tokens_per_seq=past_seen_tokens,
),
kv_cache_manager=kv_cache_manager,
request_ids=request_ids,
prompt_lens=sequence_lengths,
max_num_requests=self.batch_size,
max_num_tokens=self.batch_size * self.seq_len,
)
lora_params_pytorch_flow = {
'num_seqs': self.batch_size,
'host_request_types': torch.zeros(self.batch_size,
dtype=torch.int32),
'prompt_lens_cpu': torch.tensor([self.seq_len] * self.batch_size),
'host_request_types': host_request_types,
'prompt_lens_cpu': host_context_lengths,
'remove_input_padding': True,
0: { # layer_idx
LoraModuleType.ATTENTION_Q: { # Module type
'adapter_size':
torch.tensor([8]),
lora_params['lora_ranks'],
'weight_pointers':
torch.tensor([[
lora_params['lora_weight_outs'][0].data_ptr(),
lora_params['lora_weight_ins'][0].data_ptr()
]]),
lora_params['lora_weights_pointers'],
'is_dora':
False,
'weight_tensors': [
lora_params['lora_weight_outs'][0],
lora_params['lora_weight_ins'][0]
]
},
LoraModuleType.ATTENTION_K: {
'adapter_size':
torch.tensor([8]),
'weight_pointers':
torch.tensor([[
lora_params['lora_weight_outs'][0].data_ptr(),
lora_params['lora_weight_ins'][0].data_ptr()
]]),
lora_params['lora_ranks'],
'weight_pointers': lora_params['lora_weights_pointers'],
'is_dora':
False,
'weight_tensors': [
lora_params['lora_weight_outs'][0],
lora_params['lora_weight_ins'][0]
]
},
LoraModuleType.ATTENTION_V: {
'adapter_size':
torch.tensor([8]),
lora_params['lora_ranks'],
'weight_pointers':
torch.tensor([[
lora_params['lora_weight_outs'][0].data_ptr(),
lora_params['lora_weight_ins'][0].data_ptr()
]]),
lora_params['lora_weights_pointers'],
'is_dora':
False,
'weight_tensors': [
lora_params['lora_weight_outs'][0],
lora_params['lora_weight_ins'][0]
]
},
LoraModuleType.ATTENTION_DENSE: {
'adapter_size':
torch.tensor([8]),
lora_params['lora_ranks'],
'weight_pointers':
torch.tensor([[
lora_params['lora_weight_outs'][0].data_ptr(),
lora_params['lora_weight_ins'][0].data_ptr()
]]),
lora_params['lora_weights_pointers'],
'is_dora':
False,
'weight_tensors': [
lora_params['lora_weight_outs'][0],
lora_params['lora_weight_ins'][0]
]
}
}
}
with torch.inference_mode():
attn_metadata.prepare()
hidden_states_pytorchflow = hidden_states.squeeze(0)
pytorchflow_output = attention_module(
pytorch_flow_output = attention_module(
position_ids=None,
hidden_states=hidden_states_pytorchflow,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
attention_mask=PredefinedAttentionMask.CAUSAL,
lora_params=lora_params_pytorchflow)
lora_params=lora_params_pytorch_flow)
torch.testing.assert_close(pytorchflow_output,
trt_output = outputs['output']
torch.testing.assert_close(pytorch_flow_output,
trt_output,
atol=2e-3,
rtol=0)

View File

@ -0,0 +1,180 @@
import os
import sys
import unittest
import torch
import tensorrt_llm
from tensorrt_llm import Tensor
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
from utils.util import create_session, run_session
class TestLoraPluginVsLayer(unittest.TestCase):
def setUp(self):
tensorrt_llm.logger.set_level('info')
torch.random.manual_seed(0)
self.dtype = 'float16'
self.torch_dtype = torch.float16
self.device = 'cuda'
self.batch_size = 4
self.seq_len = 8
self.hidden_size = 1024
self.lora_rank = 8
self.is_remove_input_padding = False
self.weight_index = 0
self.transA = False
self.transB = True
def _create_input_tensors(self, batch_size, seq_len, hidden_size,
lora_ranks_list):
input_tensor = torch.randn(batch_size,
seq_len,
hidden_size,
dtype=self.torch_dtype,
device=self.device) * 0.1
lora_weight_ins = [
torch.randn(hidden_size, lora_rank, device=self.device).to(
self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list
]
lora_weight_outs = [
torch.randn(lora_rank, hidden_size, device=self.device).to(
self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list
]
lora_weight_ins = [tmp.contiguous() for tmp in lora_weight_ins]
lora_weight_outs = [
tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs
]
# Create LoRA weight pointers
lora_weights_pointers = []
for in_ptr, out_ptr in zip(lora_weight_ins, lora_weight_outs):
lora_weights_pointers.append(in_ptr.data_ptr())
lora_weights_pointers.append(out_ptr.data_ptr())
# null dora scale
lora_weights_pointers.append(0)
lora_weights_pointers = torch.LongTensor(lora_weights_pointers).to(
torch.int64).reshape([batch_size, 3])
# Create other tensors
host_context_lengths = torch.Tensor(
[seq_len for _ in range(batch_size)]).to(torch.int32)
lora_ranks = torch.Tensor(lora_ranks_list).to(torch.int32)
host_request_types = torch.zeros_like(host_context_lengths,
device='cpu').int()
return {
'input_tensor': input_tensor,
'lora_weight_ins': lora_weight_ins,
'lora_weight_outs': lora_weight_outs,
'lora_weights_pointers': lora_weights_pointers,
'host_context_lengths': host_context_lengths,
'lora_ranks': lora_ranks,
'host_request_types': host_request_types,
'batch_size': batch_size,
'seq_len': seq_len,
'hidden_size': hidden_size,
'max_lora_rank': max(max(lora_ranks_list), 8)
}
def _create_lora_plugin_session(self, tensors):
# Construct TensorRT network
builder = tensorrt_llm.Builder()
network = builder.create_network()
network.plugin_config.set_lora_plugin(self.dtype)
network.plugin_config.remove_input_padding = self.is_remove_input_padding
with tensorrt_llm.net_guard(network):
input_tensor = Tensor(name='input_tensor',
shape=[
tensors['batch_size'], tensors['seq_len'],
tensors['hidden_size']
],
dtype=tensorrt_llm.str_dtype_to_trt(
self.dtype))
host_request_types_tensor = Tensor(
name='host_request_types',
shape=[tensors['batch_size']],
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
host_context_lengths_tensor = Tensor(
name='host_context_lengths',
shape=[tensors['batch_size']],
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
lora_ranks_tensor = Tensor(
name='lora_ranks',
shape=[tensors['batch_size']],
dtype=tensorrt_llm.str_dtype_to_trt('int32'))
lora_weights_pointers_tensor = Tensor(
name='lora_weights_pointers',
shape=[tensors['batch_size'], 3],
dtype=tensorrt_llm.str_dtype_to_trt('int64'))
output = tensorrt_llm.functional.lora_plugin(
input_tensor,
tensors['hidden_size'],
[tensors['hidden_size']],
host_request_types_tensor,
self.transA, # transA
self.transB, # transB
host_context_lengths_tensor,
tensors['max_lora_rank'],
[lora_ranks_tensor],
[lora_weights_pointers_tensor],
weight_index=self.weight_index,
)
output.mark_output('output')
return create_session(builder, network, precision=self.dtype)
def _run_lora_grouped_gemm(self, tensors):
"""Run the lora_grouped_gemm operation directly"""
# Prepare parameters for lora_grouped_gemm
x = tensors['input_tensor']
host_request_types = tensors[
'host_request_types'][:tensors['batch_size']]
lora_ranks = tensors['lora_ranks']
lora_weight_pointers = tensors['lora_weights_pointers']
prompt_lens_cpu = tensors[
'host_context_lengths'][:tensors['batch_size']]
output_hidden_sizes = [tensors['hidden_size']]
transA = self.transA
transB = self.transB
max_rank = max([r.item() for r in lora_ranks])
weight_index = self.weight_index
is_remove_input_padding = self.is_remove_input_padding
lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
x, host_request_types, [lora_ranks], [lora_weight_pointers],
prompt_lens_cpu, output_hidden_sizes, transA, transB, max_rank,
weight_index, is_remove_input_padding)
return lora_outputs[0]
def test_lora_plugin_vs_lora_op(self):
lora_ranks_list = [self.lora_rank] * self.batch_size
tensors = self._create_input_tensors(self.batch_size, self.seq_len,
self.hidden_size, lora_ranks_list)
session = self._create_lora_plugin_session(tensors)
inputs = {
'input_tensor': tensors['input_tensor'],
'host_request_types': tensors['host_request_types'],
'host_context_lengths': tensors['host_context_lengths'],
'lora_ranks': tensors['lora_ranks'],
'lora_weights_pointers': tensors['lora_weights_pointers'],
}
outputs = run_session(session, inputs)
torch.cuda.synchronize()
lora_outputs = self._run_lora_grouped_gemm(tensors)
torch.testing.assert_close(outputs['output'],
lora_outputs,
atol=5e-3,
rtol=0.3)
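The new test only exercises the padded path (is_remove_input_padding=False). With padding removed, the op expects a 2-D [num_tokens, hidden] input and reads per-request token counts from host_context_lengths, as suggested by getNumTokens() and the context-length expansion in loraOp.cpp. The sketch below is a hedged illustration of how the call changes, reusing the rank and pointer tensors produced by _create_input_tensors; the ragged case is not covered by this commit's tests.

# Illustrative only: same adapters and pointer layout as above, ragged input.
context_lengths = [8, 8, 8, 8]                  # one entry per request
num_tokens = sum(context_lengths)
x_ragged = torch.randn(num_tokens, 1024,        # [num_tokens, hidden]; no batch/seq dims
                       dtype=torch.float16, device='cuda') * 0.1
host_context_lengths = torch.tensor(context_lengths, dtype=torch.int32)
host_request_types = torch.zeros(len(context_lengths), dtype=torch.int32)  # all context requests

lora_outputs = torch.ops.trtllm.lora_grouped_gemm(
    x_ragged, host_request_types,
    [tensors['lora_ranks']], [tensors['lora_weights_pointers']],
    host_context_lengths,
    [1024],                                     # output hidden size per module
    False, True,                                # transA, transB
    8,                                          # max_low_rank
    0,                                          # weight_index
    True)                                       # isRemoveInputPadding
# lora_outputs[0] has shape [num_tokens, 1024]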