From 0f084d9566e3163b3233d562776454555e9c145f Mon Sep 17 00:00:00 2001
From: danielafrimi <45691845+danielafrimi@users.noreply.github.com>
Date: Thu, 17 Apr 2025 07:48:27 +0300
Subject: [PATCH] added loraOp into lora layer + test for mlp and comparison to lora plugin (#3455)

Loraop integration into torch modules

Signed-off-by: Ubuntu
---
 cpp/tensorrt_llm/thop/CMakeLists.txt          |   3 +-
 cpp/tensorrt_llm/thop/loraOp.cpp              | 192 ++++++++
 tensorrt_llm/_torch/peft/lora/layer.py        | 105 ++--
 ...test_lora_attention_pytorch_flow_vs_trt.py | 463 ++++++++++--------
 .../test_lora_plugin_vs_lora_op.py            | 180 +++++++
 5 files changed, 680 insertions(+), 263 deletions(-)
 create mode 100644 cpp/tensorrt_llm/thop/loraOp.cpp
 create mode 100644 tests/unittest/_torch/modules/tests_lora_modules/test_lora_plugin_vs_lora_op.py

diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt
index 65280bcace..17ee60bc11 100644
--- a/cpp/tensorrt_llm/thop/CMakeLists.txt
+++ b/cpp/tensorrt_llm/thop/CMakeLists.txt
@@ -70,7 +70,8 @@ add_library(
   userbuffersFinalizeOp.cpp
   userbuffersTensor.cpp
   weightOnlyQuantOp.cpp
-  mtpOp.cpp)
+  mtpOp.cpp
+  loraOp.cpp)
 set_property(TARGET th_common PROPERTY POSITION_INDEPENDENT_CODE ON)
 target_link_libraries(th_common PRIVATE ${TORCH_LIBRARIES} th_utils
                                         ${Python3_LIBRARIES} ${SHARED_TARGET})
diff --git a/cpp/tensorrt_llm/thop/loraOp.cpp b/cpp/tensorrt_llm/thop/loraOp.cpp
new file mode 100644
index 0000000000..973628725a
--- /dev/null
+++ b/cpp/tensorrt_llm/thop/loraOp.cpp
@@ -0,0 +1,192 @@
+
+/*
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/common/cublasMMWrapper.h"
+#include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/opUtils.h"
+#include "tensorrt_llm/kernels/lora/lora.h"
+#include "tensorrt_llm/kernels/selectiveScan/selectiveScan.h"
+#include "tensorrt_llm/thop/thUtils.h"
+
+namespace th = torch;
+namespace tk = tensorrt_llm::kernels;
+using tensorrt_llm::common::fmtstr;
+
+namespace torch_ext
+{
+
+enum class RequestType : int32_t
+{
+    kCONTEXT = 0,
+    kGENERATION = 1
+};
+
+int64_t getNumTokens(th::Tensor const& input)
+{
+    int ndim = input.sizes().size();
+    TLLM_CHECK_WITH_INFO(
+        3 == ndim || 2 == ndim, "hidden_state dimension should be either 2 [numTokens, hidden], or 3 [b, s, hidden]");
+    int64_t num_tokens = input.sizes()[0];
+    if (ndim == 3)
+    {
+        num_tokens *= input.sizes()[1];
+    }
+    return num_tokens;
+}
+
+std::vector<th::Tensor> lora_grouped_gemm(th::Tensor const& input, th::Tensor const& host_request_types,
+    std::vector<th::Tensor> const& lora_ranks, // numModules tensors, each tensor has a single value
+    std::vector<th::Tensor> const& lora_weights_pointers, th::Tensor const& host_context_lengths,
+    std::vector<int64_t> const& output_hidden_sizes, bool transA, bool transB, int64_t const max_low_rank,
+    int64_t const& weight_index, bool isRemoveInputPadding)
+{
+    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
+
+    auto stream = at::cuda::getCurrentCUDAStream().stream();
+    auto const numReqs = lora_ranks[0].sizes()[0];
+    auto const out_shape = input.sizes();
+    int const numLoraModules = lora_ranks.size();
+    TLLM_CHECK_WITH_INFO(lora_ranks.size() == lora_weights_pointers.size(), "both should be numLoraModules");
+    std::vector<th::Tensor> output_torch;
+    for (int i = 0; i < numLoraModules; i++)
+    {
+        std::vector<int64_t> output_shape = {out_shape[0], output_hidden_sizes[i]};
+        if (!isRemoveInputPadding)
+        {
+            output_shape = {out_shape[0], out_shape[1], output_hidden_sizes[i]};
+        }
+        output_torch.push_back(torch::empty(output_shape, input.options()));
+    }
+    std::vector<void*> output;
+    for (auto tensor_it = output_torch.begin(); tensor_it != output_torch.end(); tensor_it++)
+    {
+        output.push_back(tensor_it->data_ptr());
+    }
+    int const seqLen = isRemoveInputPadding ? 0 : input.sizes()[1];
+    int32_t const* reqTypes = static_cast<int32_t const*>(host_request_types.data_ptr());
+    int32_t const* hostContextLengths
+        = isRemoveInputPadding ? static_cast<int32_t const*>(host_context_lengths.data_ptr()) : nullptr;
+
+    int64_t numTokens = getNumTokens(input);
+
+    std::vector<void const*> expandLoraWeightPtrs{};
+    std::vector<int32_t> expandLoraRanks{};
+
+    expandLoraWeightPtrs.reserve(numLoraModules * numTokens * 2);
+    expandLoraRanks.reserve(numLoraModules * numTokens);
+
+    for (int loraModuleIdx = 0; loraModuleIdx < numLoraModules; loraModuleIdx++)
+    {
+        auto const loraRankModule = static_cast<int32_t const*>(lora_ranks[loraModuleIdx].data_ptr());
+        auto const loraWeightModulePtrs = static_cast<int64_t const*>(lora_weights_pointers[loraModuleIdx].data_ptr());
+
+        int idx = 0;
+        for (int reqId = 0; reqId < numReqs; reqId++)
+        {
+            // loraWeightModulePtrs has 3 pointers for each module: A, B, and an optional DoRA magnitude;
+            // the current LoRA plugin does not apply DoRA scaling, so the magnitude is ignored
+            RequestType const reqType = static_cast<RequestType>(reqTypes[reqId]);
+            if (reqType == RequestType::kGENERATION)
+            {
+                expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3]));
+                expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3 + 1]));
+                expandLoraRanks.push_back(loraRankModule[reqId]);
+                idx += 1;
+            }
+            else
+            {
+                int contextLen = (isRemoveInputPadding ? hostContextLengths[reqId] : seqLen);
+
+                for (int contextId = 0; contextId < contextLen; contextId++)
+                {
+                    expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3]));
+                    expandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 3 + 1]));
+                    expandLoraRanks.push_back(loraRankModule[reqId]);
+                    idx += 1;
+                }
+            }
+        }
+
+        // In the 1st generation phase of cross attention qkv lora, cross qkv is skipped by passing an empty
+        // encoder_output (dim of 0), so getNumTokens() returns 0 for cross qkv_lora. Skip the check in that case.
+        if (numTokens > 0)
+        {
+            TLLM_CHECK_WITH_INFO(idx == numTokens,
+                fmtstr("LoraParams and input dims don't match, lora tokens %d input tokens %ld", idx, numTokens));
+        }
+    }
+
+    auto cublasHandle = getCublasHandle();
+    auto cublasLtHandle = getCublasLtHandle();
+    auto cublasWraper
+        = std::make_shared<tensorrt_llm::common::CublasMMWrapper>(cublasHandle, cublasLtHandle, nullptr, nullptr);
+
+    int const inHiddenSize = input.sizes()[input.sizes().size() - 1];
+
+    std::vector<int> outHiddenSizes(output_hidden_sizes.size());
+    for (int i = 0; i < numLoraModules; i++)
+    {
+        outHiddenSizes[i] = output_hidden_sizes[i];
+    }
+    nvinfer1::DataType loraRuntimeDataType;
+    switch (input.scalar_type())
+    {
+    case torch::kFloat16: loraRuntimeDataType = nvinfer1::DataType::kHALF; break;
+    case torch::kBFloat16: loraRuntimeDataType = nvinfer1::DataType::kBF16; break;
+    default: throw std::invalid_argument("Invalid dtype, only supports float16, bfloat16");
+    }
+
+    auto mLoraImpl = std::make_shared<tk::LoraImpl>(
+        inHiddenSize, outHiddenSizes, transA, transB, numLoraModules, loraRuntimeDataType, max_low_rank, cublasWraper);
+
+    // TODO (dafrimi): use Profiler to find the best tactic as used in lora_plugin
+    mLoraImpl->setBestTactic(std::nullopt);
+
+    auto const workspace_size = mLoraImpl->getWorkspaceSize(numTokens, numReqs, loraRuntimeDataType);
+
+    auto workspace = torch::empty(std::vector<int64_t>{static_cast<int64_t>(workspace_size)}, input.options());
+
+    mLoraImpl->run(numTokens, numReqs, input.data_ptr(), expandLoraRanks.data(), expandLoraWeightPtrs.data(),
+        weight_index, output.data(), workspace.data_ptr(), stream);
+    sync_check_cuda_error(stream);
+
+    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
+    return output_torch;
+}
+
+} // namespace torch_ext
+
+TORCH_LIBRARY_FRAGMENT(trtllm, m)
+{
+    m.def(
+        "lora_grouped_gemm(Tensor input, "
+        "Tensor host_request_types, "
+        "Tensor [] lora_ranks, "
+        "Tensor [] lora_weights_pointers, "
+        "Tensor host_context_lengths, "
+        "int [] output_hidden_sizes, "
+        "bool transA, "
+        "bool transB, "
+        "int max_low_rank, "
+        "int weight_index, "
+        "bool isRemoveInputPadding) -> Tensor[]");
+}
+
+TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
+{
+    m.impl("lora_grouped_gemm", &torch_ext::lora_grouped_gemm);
+}
diff --git a/tensorrt_llm/_torch/peft/lora/layer.py b/tensorrt_llm/_torch/peft/lora/layer.py index 8f43f245db..11ad8db3fa 100644 --- a/tensorrt_llm/_torch/peft/lora/layer.py +++ b/tensorrt_llm/_torch/peft/lora/layer.py @@ -95,90 +95,61 @@ class LoraLayer(torch.nn.Module): def forward(self, x, lora_params: Dict, layer_idx: int) -> Optional[torch.Tensor]: if bool(lora_params): + lora_ranks = [] lora_weight_pointers = [] - lora_weight_tensors = [ - ] # TODO (dafrimi) needs to delete this when we use loraOps which uses ptr active_lora_module_ids = [] for module_idx in self.lora_module_types: module_idx = int(module_idx) if module_idx in lora_params[layer_idx]: active_lora_module_ids.append(module_idx) - - is_dora = lora_params[layer_idx][module_idx][ 'is_dora'] # todo (dafrimi) use it when
calling loraOP - + # TODO (dafrimi): needs to pass this is_dora arg + lora_params[layer_idx][module_idx]['is_dora'] lora_ranks.append( lora_params[layer_idx][module_idx]['adapter_size']) lora_weight_pointers.append( lora_params[layer_idx][module_idx]['weight_pointers']) - lora_weight_tensors.append( - lora_params[layer_idx][module_idx]['weight_tensors'] - ) # TODO (dafrimi) needs to delete this when we use loraOps which uses ptr - lora_params['num_seqs'] + num_seqs = lora_params['num_seqs'] if len(active_lora_module_ids) == 0: return None else: - # If there's only one module, use the existing implementation - if len(active_lora_module_ids) == 1: - lora_weight_tensors = lora_weight_tensors[0] - lora_output = ( - x @ lora_weight_tensors[1].T) @ lora_weight_tensors[0].T + lora_outputs = torch.ops.trtllm.lora_grouped_gemm( + x, + lora_params['host_request_types'][:num_seqs], + lora_ranks, + lora_weight_pointers, + lora_params['prompt_lens_cpu'][:num_seqs], + self.output_hidden_sizes, + False, # transA + True, # transB + max([r.max() for r in lora_ranks]), + 0, + lora_params["remove_input_padding"], + ) + if isinstance(lora_outputs, torch.Tensor): + return lora_outputs + else: + # For multiple LoRA modules, some might not be executed in grouped gemm. + # For those modules not executed, we create zero tensors with matching dimensions. + # Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order. + lora_output = [] + for module_idx in self.lora_module_types: + if int(module_idx) in active_lora_module_ids: + lora_output.append(lora_outputs.pop(0)) + else: + lora_output.append( + torch.zeros(list(x.shape[:-1]) + [ + self.output_hidden_sizes[ + self.lora_module_types.index( + module_idx)] + ], + dtype=x.dtype, + device=x.device)) + + lora_output = torch.cat(lora_output, dim=-1) return lora_output - # For multiple modules, compute and concatenate outputs - lora_outputs = [] - for module_idx in self.lora_module_types: - module_idx = int(module_idx) - if module_idx in active_lora_module_ids: - i = active_lora_module_ids.index(module_idx) - weight_tensors = lora_weight_tensors[i] - A, B = weight_tensors[0], weight_tensors[1] - lora_output = (x @ B.T) @ A.T - lora_outputs.append(lora_output) - - # Concatenate outputs from all modules - lora_output = torch.cat(lora_outputs, dim=-1) - return lora_output - - # TODO(dafrimi): use torch implementation. For now, this is just a placeholder, until we will do the biniding to lora ops C++ - # lora_outputs = torch.ops.trtllm.lora_grouped_gemm( - # x, - # lora_params['host_request_types'][:num_seqs], - # lora_ranks, - # lora_weight_pointers, - # lora_params['prompt_lens_cpu'][:num_seqs], - # self.output_hidden_sizes, - # False, # transA - # True, # transB - # max([r.max() for r in lora_ranks]), - # 0, - # True, - # ) - # if isinstance(lora_outputs, torch.Tensor): - # return lora_outputs - # else: - # # For multiple LoRA modules, some might not be executed in grouped gemm. - # # For those modules not executed, we create zero tensors with matching dimensions. - # # Finally we concatenate all tensors (both LoRA outputs and zero tensors) in order. 
- # lora_output = [] - # for module_idx in self.lora_module_types: - # if int(module_idx) in active_lora_module_ids: - # lora_output.append(lora_outputs.pop(0)) - # else: - # lora_output.append( - # torch.zeros(list(x.shape[:-1]) + [ - # self.output_hidden_sizes[ - # self.lora_module_types.index( - # module_idx)] - # ], - # dtype=x.dtype, - # device=x.device)) - - # lora_output = torch.cat(lora_output, dim=-1) - # return lora_output - else: return None diff --git a/tests/unittest/_torch/modules/tests_lora_modules/test_lora_attention_pytorch_flow_vs_trt.py b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_attention_pytorch_flow_vs_trt.py index 4ebf769069..3998127a03 100644 --- a/tests/unittest/_torch/modules/tests_lora_modules/test_lora_attention_pytorch_flow_vs_trt.py +++ b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_attention_pytorch_flow_vs_trt.py @@ -12,11 +12,13 @@ from tensorrt_llm._torch.attention_backend.utils import get_attention_backend from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_llama import LlamaAttention +# LoRA Imports from tensorrt_llm._torch.peft.lora.layer import LoraModuleType from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm._utils import str_dtype_to_torch from tensorrt_llm.bindings.executor import KvCacheConfig -from tensorrt_llm.layers import Attention +from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.layers import AttentionParams, KeyValueCacheParams from tensorrt_llm.layers.lora import Lora, LoraParams from tensorrt_llm.mapping import Mapping @@ -27,94 +29,34 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase): def setUpClass(cls): cls.batch_size = 1 cls.seq_len = 16 - cls.hidden_size = 64 cls.head_num = 1 - cls.num_hidden_layers = 1 + cls.head_size = 64 + cls.hidden_size = cls.head_num * cls.head_size cls.dtype = 'float16' cls.torch_dtype = str_dtype_to_torch(cls.dtype) cls.device = torch.device('cuda') + cls.pos_emb_type = PositionEmbeddingType.learned_absolute + cls.causal_mask = True - # KV cache parameters - cls.num_blocks = 4 - cls.tokens_per_block = 32 + def _create_lora_params(self, ): + lora_ranks_list = [8 for _ in range(self.batch_size)] - cls.llama_config = LlamaConfig(hidden_size=cls.hidden_size, - num_attention_heads=cls.head_num, - num_hidden_layers=cls.num_hidden_layers, - intermediate_size=256, - max_position_embeddings=512, - rms_norm_eps=1e-5, - vocab_size=32000, - num_key_value_heads=cls.head_num, - torch_dtype=cls.torch_dtype) - - # Create KV cache manager - head_dim = cls.llama_config.hidden_size // cls.llama_config.num_attention_heads - mapping = Mapping(world_size=1, tp_size=1, rank=0) - kv_cache_config = KvCacheConfig(max_tokens=cls.num_blocks * - cls.tokens_per_block) - cls.kv_cache_manager = KVCacheManager( - kv_cache_config=kv_cache_config, - kv_cache_type=tensorrt_llm.bindings.internal.batch_manager. 
- CacheType.SELF, - num_layers=cls.llama_config.num_hidden_layers, - num_kv_heads=cls.llama_config.num_key_value_heads, - head_dim=head_dim, - tokens_per_block=cls.tokens_per_block, - max_seq_len=cls.num_blocks * cls.tokens_per_block, - max_batch_size=cls.batch_size, - mapping=mapping, - dtype=tensorrt_llm.bindings.DataType.HALF) - - @classmethod - def tearDownClass(cls): - cls.kv_cache_manager.shutdown() - - def _create_attention_inputs(self): - hidden_states = torch.empty( - size=[self.batch_size, self.seq_len, self.hidden_size], - dtype=self.torch_dtype, - device='cuda') - hidden_states.normal_(0.0, 0.02) - - # Create weights - q_weight = torch.empty(size=[self.hidden_size, self.hidden_size], - dtype=self.torch_dtype) - torch.nn.init.xavier_uniform_(q_weight) - - # Set K and V and O weights to identity matrix - eye_weight = torch.eye(self.hidden_size, dtype=self.torch_dtype) - qkv_weight = torch.cat([q_weight, eye_weight, eye_weight], dim=-1) - out_weight = eye_weight - - return hidden_states, qkv_weight, out_weight - - def _create_lora_params(self): - lora_ranks_list = [8] - - host_context_lengths = torch.Tensor( - [self.seq_len for _ in range(self.batch_size)]).to(torch.int32) lora_ranks = torch.Tensor(lora_ranks_list).to(torch.int32) - host_request_types = torch.zeros_like(host_context_lengths, - device='cpu').int() lora_weight_ins = [ - torch.randn(self.hidden_size, lora_rank, device=self.device).to( + torch.randn(self.hidden_size, lora_rank, device="cuda").to( self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list ] lora_weight_outs = [ - torch.randn(lora_rank, self.hidden_size, device=self.device).to( + torch.randn(lora_rank, self.hidden_size, device="cuda").to( self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list ] - lora_weight_ins = [ - tmp.transpose(1, 0).contiguous() for tmp in lora_weight_ins - ] + lora_weight_ins = [tmp.contiguous() for tmp in lora_weight_ins] lora_weight_outs = [ tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs ] - # Create weight pointers for TensorRT lora_weights_pointers = [] for in_ptr, out_ptr in zip(lora_weight_ins, lora_weight_outs): lora_weights_pointers.append(in_ptr.data_ptr()) @@ -125,88 +67,164 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase): return { 'lora_ranks': lora_ranks, - 'host_context_lengths': host_context_lengths, - 'host_request_types': host_request_types, 'lora_weights_pointers': lora_weights_pointers, 'lora_weight_ins': lora_weight_ins, 'lora_weight_outs': lora_weight_outs } - def _setup_attention_module(self, qkv_weight, out_weight): - """Set up the attention module with weights.""" - model_config = ModelConfig(pretrained_config=self.llama_config, - attn_backend="VANILLA") - layer_idx = 0 - attention_module = LlamaAttention(model_config, layer_idx=layer_idx).to( - self.device).to(self.torch_dtype) + def test_lora_attention(self): - # Set weights - attention_module.qkv_proj.weight.data = torch.from_numpy( - np.ascontiguousarray(qkv_weight.cpu().numpy().transpose( - [1, 0]))).to(self.device) - attention_module.o_proj.weight.data = torch.from_numpy( - np.ascontiguousarray(out_weight.cpu().numpy().transpose( - [1, 0]))).to(self.device) + mean = 0.0 + std_dev = 0.02 if self.dtype == "float32" else 0.005 - return attention_module, model_config + hidden_states = torch.concat([ + torch.empty(size=[self.seq_len, self.hidden_size], + dtype=self.torch_dtype, + device=self.device).normal_(mean, std_dev) + for _ in range(self.batch_size) + ]) - def _create_attention_metadata(self, model_config): - 
sequence_lengths = [self.seq_len] - past_seen_tokens = [0] - request_ids = [0] - token_nums = [self.seq_len] - prompt_lens = token_nums + context_lengths = torch.full([self.batch_size], + self.seq_len, + dtype=torch.int32, + device=self.device) - self.kv_cache_manager.add_dummy_requests(request_ids, token_nums) + # Plugin specific setup - only generate 1 step + max_seq_len = self.seq_len + 1 - metadata_cls = get_attention_backend(model_config.attn_backend).Metadata - return metadata_cls( - seq_lens=torch.tensor(sequence_lengths, dtype=torch.int32), - num_contexts=len(sequence_lengths), - kv_cache_params=KVCacheParams( - use_cache=True, - num_cached_tokens_per_seq=past_seen_tokens, - ), - kv_cache_manager=self.kv_cache_manager, - request_ids=request_ids, - prompt_lens=prompt_lens, - max_num_requests=self.batch_size, - max_num_tokens=self.batch_size * self.seq_len, - ) + # zero means "valid" token, one means invalid. + host_past_key_value_lengths = torch.tensor([0] * self.batch_size, + dtype=torch.int32) + + # the max kv cache length for each layer. single tensor since we only have 1 layer here. + host_max_attention_window_sizes = torch.tensor([max_seq_len], + dtype=torch.int32) + host_sink_token_length = torch.tensor([0], dtype=torch.int32) + + sequence_length = torch.full([self.batch_size], + self.seq_len, + dtype=torch.int32, + device=self.device) + + # even in the the context phase, kv cache tensors can not be empty tensor for plugin, the actual shape info + # otherwise, there will be cublas execution error. + # are passed to plugin by the `sequence_length` tensor + kv_shape = (self.batch_size, 2, self.head_num, max_seq_len, + self.head_size) + past_key_value = torch.randn(kv_shape, + dtype=self.torch_dtype, + device=self.device) + cache_indirection = torch.full(( + self.batch_size, + 1, + max_seq_len, + ), + 0, + dtype=torch.int32, + device=self.device) + + host_request_types = torch.tensor([0] * self.batch_size, + dtype=torch.int32, + device='cpu') + + perf_knob_tensor_size = 16 + host_runtime_perf_knobs_tensor = torch.tensor([-1] * + perf_knob_tensor_size, + dtype=torch.int64, + device='cpu') + host_context_progress = torch.tensor([0], + dtype=torch.int64, + device='cpu') + + host_context_lengths = torch.Tensor( + [self.seq_len for _ in range(self.batch_size)]).to(torch.int32) + + q_weight = torch.empty(size=[self.hidden_size, self.hidden_size], + dtype=self.torch_dtype) + torch.nn.init.xavier_uniform_(q_weight) + + # The initialization here is chosen to minimize computation after the + # QKV BMMs in order to reduce the amount of differences from FP accumulation. + # We set K and V weights to the identity matrix so that the input is copied + # without doing any accumulation. Additionally, we set the output projection + # to the identity for the same reason. + # The main purpose of these tests is to check the QK^T BMM + Softmax + SV BMM for LoRA. 
+ eye_weight = torch.eye(self.hidden_size, dtype=self.torch_dtype) + qkv_weight = torch.cat([q_weight, eye_weight, eye_weight], dim=-1) + + out_weight = eye_weight + + lora_params = self._create_lora_params() - def _setup_trt_network(self, hidden_states, lora_params, attention_module): builder = tensorrt_llm.Builder() net = builder.create_network() net.plugin_config.to_legacy_setting() - net.plugin_config.lora_plugin = self.dtype - + net.plugin_config.gpt_attention_plugin = self.dtype # for ragged input we use this plugin with remove_input_padding + net.plugin_config.remove_input_padding = True + net.plugin_config.lora_plugin = "float16" with tensorrt_llm.net_guard(net): - # Create input tensor trt_hidden_states = Tensor(name='hidden_states', shape=hidden_states.shape, dtype=tensorrt_llm.str_dtype_to_trt( self.dtype)) - - # Create LoRA tensors - host_request_types_tensor = Tensor( - name='host_request_types', - shape=[lora_params['host_request_types'].shape[0]], + context_lengths_tensor = Tensor( + name='context_lengths', + shape=context_lengths.shape, dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_context_lengths_tensor = Tensor( name='host_context_lengths', - shape=[lora_params['host_context_lengths'].shape[0]], + shape=[host_context_lengths.shape[0]], dtype=tensorrt_llm.str_dtype_to_trt('int32')) + lora_ranks_tensor = Tensor( name='lora_ranks', - shape=[lora_params['lora_ranks'].shape[0]], + shape=(lora_params['lora_ranks'].shape[0], ), dtype=tensorrt_llm.str_dtype_to_trt('int32')) + lora_weights_pointers_tensor = Tensor( name='lora_weights_pointers', shape=lora_params['lora_weights_pointers'].shape, dtype=tensorrt_llm.str_dtype_to_trt('int64')) - # Create LoRA parameters - lora_params = LoraParams( + host_request_types_tensor = Tensor( + name='host_request_types', + shape=host_request_types.shape, + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + past_key_value_tensor = Tensor(name='past_key_value', + shape=tuple(past_key_value.shape), + dtype=tensorrt_llm.str_dtype_to_trt( + self.dtype)) + sequence_length_tensor = Tensor( + name='sequence_length', + shape=tuple(sequence_length.shape), + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_past_key_value_lengths_tensor = Tensor( + name='host_past_key_value_lengths', + shape=tuple(host_past_key_value_lengths.shape), + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_max_attention_window_sizes_tensor = Tensor( + name='host_max_attention_window_sizes', + shape=tuple(host_max_attention_window_sizes.shape), + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_sink_token_length_tensor = Tensor( + name='host_sink_token_length', + shape=tuple(host_sink_token_length.shape), + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + cache_indirection_tensor = Tensor( + name='cache_indirection', + shape=tuple(cache_indirection.shape), + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_runtime_perf_knobs = Tensor( + name='host_runtime_perf_knobs', + shape=[16], + dtype=tensorrt_llm.str_dtype_to_trt('int64')) + host_context_progress_tensor = Tensor( + name='host_context_progress', + shape=[1], + dtype=tensorrt_llm.str_dtype_to_trt('int64')) + + lora_layer_params = LoraParams( lora_ranks=[{ "attn_q_lora_ranks": lora_ranks_tensor, "attn_k_lora_ranks": lora_ranks_tensor, @@ -224,15 +242,17 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase): lora_weights_pointers_tensor, }], host_context_lengths=host_context_lengths_tensor, - host_request_types=host_request_types_tensor) + host_request_types=host_request_types_tensor, + ) - attn_layer 
= Attention( + attn_layer = tensorrt_llm.layers.Attention( local_layer_idx=0, - hidden_size=hidden_states.shape[-1], - num_attention_heads=1, - max_position_embeddings=hidden_states.shape[1], + hidden_size=self.hidden_size, + num_attention_heads=self.head_num, + max_position_embeddings=self.seq_len, attention_mask_type=tensorrt_llm.layers.AttentionMaskType. causal, + position_embedding_type=self.pos_emb_type, bias=False) attn_layer.qkv_lora = Lora( @@ -254,29 +274,62 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase): max_low_rank=8, ) - # Set attention layer weights - attn_layer.qkv.weight.value = attention_module.qkv_proj.weight.data - attn_layer.dense.weight.value = attention_module.o_proj.weight.data + attn_layer.qkv.weight.value = np.ascontiguousarray( + qkv_weight.cpu().numpy().transpose([1, 0])) + attn_layer.dense.weight.value = np.ascontiguousarray( + out_weight.cpu().numpy().transpose([1, 0])) - output = attn_layer(hidden_states=trt_hidden_states, - lora_layer_params=lora_params) + output, present_key_value = attn_layer( + trt_hidden_states, + use_cache=True, + lora_layer_params= + lora_layer_params, # Always use cache for plugin path in this test + kv_cache_params=KeyValueCacheParams( + past_key_value=[past_key_value_tensor], + host_past_key_value_lengths= + host_past_key_value_lengths_tensor, + host_max_attention_window_sizes= + host_max_attention_window_sizes_tensor, + host_sink_token_length=host_sink_token_length_tensor, + cache_indirection=cache_indirection_tensor), + attention_params=AttentionParams( + sequence_length=sequence_length_tensor, + context_lengths=context_lengths_tensor, + host_request_types=host_request_types_tensor, + max_context_length=self.seq_len, + host_runtime_perf_knobs=host_runtime_perf_knobs, + host_context_progress=host_context_progress_tensor, + host_context_lengths=host_context_lengths_tensor, + )) + + assert isinstance(output, Tensor) output.mark_output('output', tensorrt_llm.str_dtype_to_trt(self.dtype)) + present_key_value.mark_output( + 'present_key_value', tensorrt_llm.str_dtype_to_trt(self.dtype)) - return builder, net - - def _run_trt_inference(self, builder, net, hidden_states, lora_params): - builder_config = builder.create_builder_config(name='attention', + builder_config = builder.create_builder_config(name='attention_plugin', precision=self.dtype) + engine_buffer = builder.build_engine(net, builder_config) session = tensorrt_llm.runtime.Session.from_serialized_engine( engine_buffer) stream = torch.cuda.current_stream().cuda_stream + inputs = { 'hidden_states': hidden_states, - 'host_request_types': lora_params['host_request_types'], - 'host_context_lengths': lora_params['host_context_lengths'], + 'past_key_value': past_key_value, + 'sequence_length': sequence_length, + 'host_past_key_value_lengths': host_past_key_value_lengths, + 'host_max_attention_window_sizes': host_max_attention_window_sizes, + 'host_sink_token_length': host_sink_token_length, + 'context_lengths': context_lengths, + 'host_request_types': host_request_types, + 'cache_indirection': cache_indirection, + 'host_runtime_perf_knobs': host_runtime_perf_knobs_tensor, + 'host_context_progress': host_context_progress, + 'host_context_lengths': host_context_lengths, 'lora_ranks': lora_params['lora_ranks'], 'lora_weights_pointers': lora_params['lora_weights_pointers'], } @@ -286,108 +339,128 @@ class TestLoraAttentionPytorchFlowVsTRT(unittest.TestCase): torch.empty(hidden_states.shape, dtype=tensorrt_llm._utils.str_dtype_to_torch( self.dtype), - device='cuda'), + 
device=self.device), + 'present_key_value': + past_key_value, } session.run(inputs=inputs, outputs=outputs, stream=stream) torch.cuda.synchronize() - return outputs['output'].squeeze(0) + # Pytorch flow + llama_config = LlamaConfig(hidden_size=self.hidden_size, + num_attention_heads=self.head_num, + num_hidden_layers=1, + intermediate_size=256, + max_position_embeddings=512, + rms_norm_eps=1e-5, + vocab_size=32000, + num_key_value_heads=self.head_num, + torch_dtype=self.torch_dtype) - def test_attention_with_lora(self): - hidden_states, qkv_weight, out_weight = self._create_attention_inputs() + mapping = Mapping(world_size=1, tp_size=1, rank=0) + kv_cache_config = KvCacheConfig(max_tokens=max_seq_len) + head_dim = llama_config.hidden_size // llama_config.num_attention_heads + kv_cache_manager = KVCacheManager( + kv_cache_config=kv_cache_config, + kv_cache_type=tensorrt_llm.bindings.internal.batch_manager. + CacheType.SELF, + num_layers=llama_config.num_hidden_layers, + num_kv_heads=llama_config.num_key_value_heads, + head_dim=head_dim, + tokens_per_block=128, + max_seq_len=max_seq_len, + max_batch_size=self.batch_size, + mapping=mapping, + dtype=tensorrt_llm.bindings.DataType.HALF) - lora_params = self._create_lora_params() + model_config = ModelConfig(pretrained_config=llama_config, + attn_backend="VANILLA") + attention_module = LlamaAttention(model_config, layer_idx=0).to( + self.device).to(self.torch_dtype) - attention_module, model_config = self._setup_attention_module( - qkv_weight, out_weight) + attention_module.qkv_proj.weight.data = torch.from_numpy( + np.ascontiguousarray(qkv_weight.cpu().numpy().transpose( + [1, 0]))).to(self.device) + attention_module.o_proj.weight.data = torch.from_numpy( + np.ascontiguousarray(out_weight.cpu().numpy().transpose( + [1, 0]))).to(self.device) - attn_metadata = self._create_attention_metadata(model_config) - builder, net = self._setup_trt_network(hidden_states, lora_params, - attention_module) - trt_output = self._run_trt_inference(builder, net, hidden_states, - lora_params) + request_ids = [0] - lora_params_pytorchflow = { + kv_cache_manager.add_dummy_requests(request_ids=request_ids, + token_nums=[self.seq_len]) + sequence_lengths = [self.seq_len] + past_seen_tokens = [0] + metadata_cls = get_attention_backend(model_config.attn_backend).Metadata + attn_metadata = metadata_cls( + seq_lens=torch.tensor(sequence_lengths, dtype=torch.int32), + num_contexts=len(sequence_lengths), + kv_cache_params=KVCacheParams( + use_cache=True, + num_cached_tokens_per_seq=past_seen_tokens, + ), + kv_cache_manager=kv_cache_manager, + request_ids=request_ids, + prompt_lens=sequence_lengths, + max_num_requests=self.batch_size, + max_num_tokens=self.batch_size * self.seq_len, + ) + + + lora_params_pytorch_flow = { 'num_seqs': self.batch_size, - 'host_request_types': torch.zeros(self.batch_size, - dtype=torch.int32), - 'prompt_lens_cpu': torch.tensor([self.seq_len] * self.batch_size), + 'host_request_types':host_request_types, + 'prompt_lens_cpu': host_context_lengths, + 'remove_input_padding': True, 0: { # layer_idx LoraModuleType.ATTENTION_Q: { # Module type 'adapter_size': - torch.tensor([8]), + lora_params['lora_ranks'], 'weight_pointers': - torch.tensor([[ - lora_params['lora_weight_outs'][0].data_ptr(), - lora_params['lora_weight_ins'][0].data_ptr() - ]]), + lora_params['lora_weights_pointers'], 'is_dora': False, - 'weight_tensors': [ - lora_params['lora_weight_outs'][0], - lora_params['lora_weight_ins'][0] - ] }, LoraModuleType.ATTENTION_K: { 'adapter_size': - 
torch.tensor([8]), - 'weight_pointers': - torch.tensor([[ - lora_params['lora_weight_outs'][0].data_ptr(), - lora_params['lora_weight_ins'][0].data_ptr() - ]]), + lora_params['lora_ranks'], + 'weight_pointers': lora_params['lora_weights_pointers'], 'is_dora': False, - 'weight_tensors': [ - lora_params['lora_weight_outs'][0], - lora_params['lora_weight_ins'][0] - ] }, LoraModuleType.ATTENTION_V: { 'adapter_size': - torch.tensor([8]), + lora_params['lora_ranks'], 'weight_pointers': - torch.tensor([[ - lora_params['lora_weight_outs'][0].data_ptr(), - lora_params['lora_weight_ins'][0].data_ptr() - ]]), + lora_params['lora_weights_pointers'], 'is_dora': False, - 'weight_tensors': [ - lora_params['lora_weight_outs'][0], - lora_params['lora_weight_ins'][0] - ] }, LoraModuleType.ATTENTION_DENSE: { 'adapter_size': - torch.tensor([8]), + lora_params['lora_ranks'], 'weight_pointers': - torch.tensor([[ - lora_params['lora_weight_outs'][0].data_ptr(), - lora_params['lora_weight_ins'][0].data_ptr() - ]]), + lora_params['lora_weights_pointers'], 'is_dora': False, - 'weight_tensors': [ - lora_params['lora_weight_outs'][0], - lora_params['lora_weight_ins'][0] - ] } } } with torch.inference_mode(): attn_metadata.prepare() - hidden_states_pytorchflow = hidden_states.squeeze(0) - pytorchflow_output = attention_module( + + pytorch_flow_output = attention_module( position_ids=None, - hidden_states=hidden_states_pytorchflow, + hidden_states=hidden_states, attn_metadata=attn_metadata, attention_mask=PredefinedAttentionMask.CAUSAL, - lora_params=lora_params_pytorchflow) + lora_params=lora_params_pytorch_flow) - torch.testing.assert_close(pytorchflow_output, + trt_output = outputs['output'] + + torch.testing.assert_close(pytorch_flow_output, trt_output, atol=2e-3, rtol=0) diff --git a/tests/unittest/_torch/modules/tests_lora_modules/test_lora_plugin_vs_lora_op.py b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_plugin_vs_lora_op.py new file mode 100644 index 0000000000..80478d1bdc --- /dev/null +++ b/tests/unittest/_torch/modules/tests_lora_modules/test_lora_plugin_vs_lora_op.py @@ -0,0 +1,180 @@ +import os +import sys +import unittest + +import torch + +import tensorrt_llm +from tensorrt_llm import Tensor + +sys.path.append(os.path.join(os.path.dirname(__file__), '../../../')) +from utils.util import create_session, run_session + + +class TestLoraPluginVsLayer(unittest.TestCase): + + def setUp(self): + tensorrt_llm.logger.set_level('info') + torch.random.manual_seed(0) + self.dtype = 'float16' + self.torch_dtype = torch.float16 + self.device = 'cuda' + self.batch_size = 4 + self.seq_len = 8 + self.hidden_size = 1024 + self.lora_rank = 8 + self.is_remove_input_padding = False + self.weight_index = 0 + self.transA = False + self.transB = True + + def _create_input_tensors(self, batch_size, seq_len, hidden_size, + lora_ranks_list): + input_tensor = torch.randn(batch_size, + seq_len, + hidden_size, + dtype=self.torch_dtype, + device=self.device) * 0.1 + + lora_weight_ins = [ + torch.randn(hidden_size, lora_rank, device=self.device).to( + self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list + ] + lora_weight_outs = [ + torch.randn(lora_rank, hidden_size, device=self.device).to( + self.torch_dtype) * 0.1 for lora_rank in lora_ranks_list + ] + + lora_weight_ins = [tmp.contiguous() for tmp in lora_weight_ins] + lora_weight_outs = [ + tmp.transpose(1, 0).contiguous() for tmp in lora_weight_outs + ] + + # Create LoRA weight pointers + lora_weights_pointers = [] + for in_ptr, out_ptr in zip(lora_weight_ins, 
lora_weight_outs): + lora_weights_pointers.append(in_ptr.data_ptr()) + lora_weights_pointers.append(out_ptr.data_ptr()) + # null dora scale + lora_weights_pointers.append(0) + + lora_weights_pointers = torch.LongTensor(lora_weights_pointers).to( + torch.int64).reshape([batch_size, 3]) + + # Create other tensors + host_context_lengths = torch.Tensor( + [seq_len for _ in range(batch_size)]).to(torch.int32) + lora_ranks = torch.Tensor(lora_ranks_list).to(torch.int32) + host_request_types = torch.zeros_like(host_context_lengths, + device='cpu').int() + + return { + 'input_tensor': input_tensor, + 'lora_weight_ins': lora_weight_ins, + 'lora_weight_outs': lora_weight_outs, + 'lora_weights_pointers': lora_weights_pointers, + 'host_context_lengths': host_context_lengths, + 'lora_ranks': lora_ranks, + 'host_request_types': host_request_types, + 'batch_size': batch_size, + 'seq_len': seq_len, + 'hidden_size': hidden_size, + 'max_lora_rank': max(max(lora_ranks_list), 8) + } + + def _create_lora_plugin_session(self, tensors): + # Construct TensorRT network + builder = tensorrt_llm.Builder() + network = builder.create_network() + network.plugin_config.set_lora_plugin(self.dtype) + network.plugin_config.remove_input_padding = self.is_remove_input_padding + + with tensorrt_llm.net_guard(network): + input_tensor = Tensor(name='input_tensor', + shape=[ + tensors['batch_size'], tensors['seq_len'], + tensors['hidden_size'] + ], + dtype=tensorrt_llm.str_dtype_to_trt( + self.dtype)) + host_request_types_tensor = Tensor( + name='host_request_types', + shape=[tensors['batch_size']], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + host_context_lengths_tensor = Tensor( + name='host_context_lengths', + shape=[tensors['batch_size']], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + lora_ranks_tensor = Tensor( + name='lora_ranks', + shape=[tensors['batch_size']], + dtype=tensorrt_llm.str_dtype_to_trt('int32')) + lora_weights_pointers_tensor = Tensor( + name='lora_weights_pointers', + shape=[tensors['batch_size'], 3], + dtype=tensorrt_llm.str_dtype_to_trt('int64')) + + output = tensorrt_llm.functional.lora_plugin( + input_tensor, + tensors['hidden_size'], + [tensors['hidden_size']], + host_request_types_tensor, + self.transA, # transA + self.transB, # transB + host_context_lengths_tensor, + tensors['max_lora_rank'], + [lora_ranks_tensor], + [lora_weights_pointers_tensor], + weight_index=self.weight_index, + ) + output.mark_output('output') + + return create_session(builder, network, precision=self.dtype) + + def _run_lora_grouped_gemm(self, tensors): + """Run the lora_grouped_gemm operation directly""" + # Prepare parameters for lora_grouped_gemm + x = tensors['input_tensor'] + host_request_types = tensors[ + 'host_request_types'][:tensors['batch_size']] + lora_ranks = tensors['lora_ranks'] + lora_weight_pointers = tensors['lora_weights_pointers'] + prompt_lens_cpu = tensors[ + 'host_context_lengths'][:tensors['batch_size']] + output_hidden_sizes = [tensors['hidden_size']] + transA = self.transA + transB = self.transB + max_rank = max([r.item() for r in lora_ranks]) + weight_index = self.weight_index + is_remove_input_padding = self.is_remove_input_padding + + lora_outputs = torch.ops.trtllm.lora_grouped_gemm( + x, host_request_types, [lora_ranks], [lora_weight_pointers], + prompt_lens_cpu, output_hidden_sizes, transA, transB, max_rank, + weight_index, is_remove_input_padding) + + return lora_outputs[0] + + def test_lora_plugin_vs_lora_op(self): + lora_ranks_list = [self.lora_rank] * self.batch_size + + tensors = 
self._create_input_tensors(self.batch_size, self.seq_len, + self.hidden_size, lora_ranks_list) + + session = self._create_lora_plugin_session(tensors) + inputs = { + 'input_tensor': tensors['input_tensor'], + 'host_request_types': tensors['host_request_types'], + 'host_context_lengths': tensors['host_context_lengths'], + 'lora_ranks': tensors['lora_ranks'], + 'lora_weights_pointers': tensors['lora_weights_pointers'], + } + outputs = run_session(session, inputs) + torch.cuda.synchronize() + + lora_outputs = self._run_lora_grouped_gemm(tensors) + + torch.testing.assert_close(outputs['output'], + lora_outputs, + atol=5e-3, + rtol=0.3)
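
Note for reviewers: the sketch below mirrors, in plain PyTorch, what the new LoraLayer.forward path expects back from torch.ops.trtllm.lora_grouped_gemm: one low-rank product per active module, with zero-filled tensors substituted for modules the grouped GEMM did not execute, before the final concatenation. It only illustrates the math and the concat order used in layer.py, not the CUDA grouped-GEMM path; the [hidden, rank] / [rank, hidden] weight shapes and the helper names (lora_module_reference, concat_module_outputs) are assumptions made for this example, not code from this patch.

import torch

def lora_module_reference(x: torch.Tensor, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Plain LoRA for a single module: project down to the adapter rank, then back up.
    # Assumed shapes: x [tokens, hidden], a [hidden, rank], b [rank, hidden].
    return (x @ a) @ b

def concat_module_outputs(module_types, active_module_ids, grouped_outputs,
                          output_hidden_sizes, x):
    # Mirrors the fallback in LoraLayer.forward: modules that were not executed by the
    # grouped GEMM contribute zero tensors, so the concatenation keeps the per-module
    # column order defined by lora_module_types.
    remaining = list(grouped_outputs)
    pieces = []
    for i, module_id in enumerate(module_types):
        if int(module_id) in active_module_ids:
            pieces.append(remaining.pop(0))
        else:
            pieces.append(torch.zeros(*x.shape[:-1], output_hidden_sizes[i],
                                      dtype=x.dtype, device=x.device))
    return torch.cat(pieces, dim=-1)

The unit tests in this patch compare the op against the TensorRT LoRA plugin rather than against such a Python reference, but the concat-with-zeros behavior above follows the fallback branch added to layer.py.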