From fac47e2826ba4ffb6deb5d09eee081d221d0a66f Mon Sep 17 00:00:00 2001 From: amitz-nv <203509407+amitz-nv@users.noreply.github.com> Date: Sun, 12 Oct 2025 22:29:52 +0300 Subject: [PATCH] [https://nvbugs/5510879][fix] Fix pytorch & TRT-python flows fused LoRA adapter modules weight split with TP>1 (#8063) Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com> --- .../_torch/pyexecutor/resource_manager.py | 29 ++-- tensorrt_llm/_utils.py | 6 +- tensorrt_llm/executor/base_worker.py | 9 +- tensorrt_llm/lora_helper.py | 1 + tensorrt_llm/lora_manager.py | 125 +++++++++++------- tensorrt_llm/runtime/enc_dec_model_runner.py | 12 +- tensorrt_llm/runtime/generation.py | 38 +++++- tensorrt_llm/runtime/model_runner.py | 8 +- tensorrt_llm/runtime/model_runner_cpp.py | 12 +- tests/unittest/llmapi/lora_test_utils.py | 56 +++++++- .../llmapi/test_llm_multi_gpu_pytorch.py | 11 +- 11 files changed, 229 insertions(+), 78 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 6e4f4a9849..bc2804584a 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -13,6 +13,7 @@ from tensorrt_llm._utils import mpi_disabled from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE from tensorrt_llm.lora_helper import LoraConfig from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig +from tensorrt_llm.runtime import ModelConfig as ModelConfigPython from tensorrt_llm.sampling_params import SamplingParams from ..._utils import binding_to_str_dtype, get_size_in_bytes, nvtx_range @@ -32,7 +33,7 @@ BufferManagerCpp = tensorrt_llm.bindings.internal.runtime.BufferManager KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType -ModelConfig = tensorrt_llm.bindings.ModelConfig +ModelConfigCpp = tensorrt_llm.bindings.ModelConfig DataType = tensorrt_llm.bindings.DataType KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEventManager RequestList = list[LlmRequest] @@ -160,7 +161,7 @@ class KVCacheManager(BaseResourceManager): spec_config: Optional["DecodingBaseConfig"] = None, layer_mask: Optional[List[bool]] = None, max_num_tokens: int = 8192, - model_config: Optional[ModelConfig] = None, + model_config: Optional[ModelConfigCpp] = None, max_beam_width: int = 1, is_draft: bool = False, kv_connector_manager: Optional[KvCacheConnectorManager] = None, @@ -371,7 +372,7 @@ class KVCacheManager(BaseResourceManager): @classmethod def from_model_config(cls, - model_config: ModelConfig, + model_config: ModelConfigCpp, kv_cache_config: KvCacheConfigCpp, mapping: Mapping, kv_cache_type: CacheTypeCpp = CacheTypeCpp.SELF, @@ -772,7 +773,7 @@ class KVCacheManager(BaseResourceManager): window_size_to_layers: Dict[int, List[int]], max_attention_window_vec: List[int], kv_cache_config: KvCacheConfigCpp, - model_config: ModelConfig, + model_config: ModelConfigCpp, pool_memory_bytes: int, kv_factor: int, dtype: DataType, @@ -887,7 +888,7 @@ class KVCacheManager(BaseResourceManager): def calculate_max_num_blocks_from_cpp( self, kv_cache_config: KvCacheConfigCpp, - model_config: ModelConfig, + model_config: ModelConfigCpp, extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]: """ This function is a wrapper of KVCacheManagerCpp.calculate_max_num_blocks. 
@@ -1133,7 +1134,7 @@ class PeftCacheManager(BaseResourceManager): def __init__(self, peft_cache_config: PeftCacheConfig, lora_config: LoraConfig, - model_config: ModelConfig, + model_config: ModelConfigCpp, world_config: WorldConfig | None = None): import tensorrt_llm.bindings as _tb @@ -1169,7 +1170,20 @@ class PeftCacheManager(BaseResourceManager): lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size, binding_to_str_dtype(model_config.data_type), lora_config.swap_gate_up_proj_lora_b_weight) - self._lora_manager = LoraManager() + mapping = Mapping( + world_size=world_config.size, + rank=world_config.rank, + tp_size=world_config.tensor_parallelism, + pp_size=world_config.pipeline_parallelism, + gpus_per_node=world_config.gpus_per_node, + ) + self._lora_manager = LoraManager( + mapping=mapping, + model_config=ModelConfigPython.from_model_config_cpp(model_config), + cpp_peft_cache_manager=self.impl) + + def get_lora_manager(self) -> LoraManager: + return self._lora_manager def add_request_peft(self, request: LlmRequest): if request.lora_task_id is not None: @@ -1183,7 +1197,6 @@ class PeftCacheManager(BaseResourceManager): self._lora_manager.load_from_ckpt( [request.py_lora_path], model_config=self._lora_model_config, - runtime_mapping=None, uids=[request.lora_task_id], ckpt_source=self._lora_config.lora_ckpt_source) request.lora_weights = self._lora_manager.cpp_lora_weights[ diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index baef7e79a6..4c696511dc 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -42,7 +42,7 @@ import torch import tensorrt as trt # isort: on -from tensorrt_llm.bindings import DataType, GptJsonConfig +from tensorrt_llm.bindings import DataType, GptJsonConfig, LayerType from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE from tensorrt_llm.logger import logger @@ -198,6 +198,10 @@ _binding_dtype_bits = { } +def binding_layer_type_to_str(layer_type: LayerType) -> str: + return layer_type.name.lower() + + def binding_to_str_dtype(binding_dtype) -> str: ret = _binding_to_str_dtype.get(binding_dtype) assert ret is not None, f'Unsupported binding dtype: {binding_dtype}' diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 498f976ecc..f2655cafb4 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -205,7 +205,10 @@ class BaseWorker(GenerationExecutor): # point in the TRT flow is currently not supported (it's at the CPP # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA # optimization is not available in TRT-python flow. 
- self._lora_manager = LoraManager(cpp_peft_cache_manager=None) + self._lora_manager = LoraManager( + mapping=engine_config.pretrained_config.mapping, + model_config=self._runtime_model_config, + cpp_peft_cache_manager=None) if engine_config.build_config.max_prompt_embedding_table_size > 0: self._prompt_adapter_manager = PromptAdapterManager() @@ -216,8 +219,7 @@ class BaseWorker(GenerationExecutor): ResourceManagerType peft_cache_manager = self.engine.resource_manager.resource_managers.get( ResourceManagerType.PEFT_CACHE_MANAGER) - self._lora_manager = LoraManager( - cpp_peft_cache_manager=peft_cache_manager.impl) + self._lora_manager = peft_cache_manager.get_lora_manager() lora_model_config = self.engine.model_engine.lora_model_config assert lora_model_config is not None self._lora_model_config = lora_model_config @@ -302,7 +304,6 @@ class BaseWorker(GenerationExecutor): [lora_request.path], model_config=self._runtime_model_config if self._runtime_model_config is not None else self._lora_model_config, - runtime_mapping=None, uids=[adapter_id], ckpt_source=lora_request.ckpt_source) return adapter_id in newly_loaded_uids diff --git a/tensorrt_llm/lora_helper.py b/tensorrt_llm/lora_helper.py index 719df51079..b9c1423272 100644 --- a/tensorrt_llm/lora_helper.py +++ b/tensorrt_llm/lora_helper.py @@ -46,6 +46,7 @@ def get_default_trtllm_modules_to_hf_modules(): "attn_q": "q_proj", "attn_k": "k_proj", "attn_v": "v_proj", + "attn_qkv": "qkv_proj", "attn_dense": "o_proj", "mlp_h_to_4h": "gate_proj", "mlp_4h_to_h": "down_proj", diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index c7dc6f28bc..4fe0d0b44c 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -1,4 +1,5 @@ import io +import itertools import json import logging import re @@ -660,11 +661,17 @@ class LoraManager(object): } def __init__( - self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None + self, + *, + mapping: Mapping, + model_config: "ModelConfig", + cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None, ): """Constructor. Args: + mapping (Mapping): Parallelism related information. + model_config (ModelConfig): model configuration python class instance. cpp_peft_cache_manager (PeftCacheManager, optional): used by is_adapter_in_cpu_cache method, that's used for a performance optimization with LoRA of not sending the LoRA adapter weights with every LLM request when the adapter is already loaded in the LoRA CPU cache. 
@@ -704,6 +711,8 @@ class LoraManager(object): self._cpp_lora_weights: Dict[str, torch.Tensor] = {} # on cpu self._cpp_lora_config: Dict[str, torch.Tensor] = {} # on cpu self.lora_target_modules: List[str] = [] + self._mapping = mapping + self._model_config = model_config self._cpp_peft_cache_manager = cpp_peft_cache_manager def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool: @@ -730,7 +739,6 @@ class LoraManager(object): self, model_dirs_or_files: List[str], model_config: Union["ModelConfig", LoraModelConfig], - runtime_mapping: Optional[Mapping] = None, uids: Optional[List[str]] = None, ckpt_source: str = "hf", ) -> List[str]: @@ -743,7 +751,6 @@ class LoraManager(object): return self.load_from_hf( model_dirs=model_dirs_or_files, model_config=model_config, - runtime_mapping=runtime_mapping, uids=uids, ) elif ckpt_source == "nemo": @@ -754,7 +761,6 @@ class LoraManager(object): return self.load_from_nemo( model_files=nemo_files, model_config=model_config, - runtime_mapping=runtime_mapping, uids=uids, ) else: @@ -764,7 +770,6 @@ class LoraManager(object): self, model_files: List[str], model_config: Union["ModelConfig", LoraModelConfig], - runtime_mapping: Optional[Mapping] = None, uids: Optional[List[str]] = None, ) -> List[str]: """Returns the adapter UIDs that were loaded by this call. @@ -772,11 +777,6 @@ class LoraManager(object): Note that when an adapter was already loaded before this call, it would not be included in the returned list of UIDs. """ - if runtime_mapping is None: - runtime_mapping = Mapping() - tp_size = runtime_mapping.tp_size - tp_rank = runtime_mapping.tp_rank - if uids is None: uids = [self._generate_uid() for _ in range(len(model_files))] assert len(uids) == len(model_files) @@ -829,10 +829,6 @@ class LoraManager(object): t_in = all_lora_weights[layer_idx]["in"] t_out = all_lora_weights[layer_idx]["out"] - assert t_out.shape[0] % tp_size == 0 - t_out = torch.split(t_out, t_out.shape[0] // tp_size, dim=0)[ - tp_rank - ].contiguous() else: t_in = None t_out = None @@ -882,7 +878,6 @@ class LoraManager(object): self, model_dirs: List[str], model_config: Union["ModelConfig", LoraModelConfig], - runtime_mapping: Optional[Mapping] = None, uids: Optional[List[str]] = None, component: Optional[str] = None, ) -> List[str]: @@ -939,11 +934,6 @@ class LoraManager(object): ... """ - if runtime_mapping is None: - runtime_mapping = Mapping() - tp_size = runtime_mapping.tp_size - tp_rank = runtime_mapping.tp_rank - if uids is None: uids = [self._generate_uid() for _ in range(len(model_dirs))] assert len(uids) == len(model_dirs) @@ -983,6 +973,70 @@ class LoraManager(object): lora_model[key] = value return lora_model + def interleave_fused_lora_weights_for_tp( + weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: List[int] + ) -> List[torch.Tensor]: + """Interleaves fused LoRA modules weights for TP. + e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to + torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN]) + where N=TP size. + """ # noqa: D205 + assert weight.shape[rank_dim] == sum(part_sizes) + + # Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv. + weight_parts = [ + weight.narrow(rank_dim, sum(part_sizes[:i]), part_sizes[i]) + for i in range(len(part_sizes)) + ] + for i in range(len(part_sizes)): + assert weight_parts[i].shape[rank_dim] % tp_size == 0 + + # Split each part into tp_size chunks. + # e.g. 
[Wq, Wk, Wv] -> [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]] + # where N is TP size, for attn_qkv. + weight_parts_tp_weights = [ + torch.split( + weight_parts[i], weight_parts[i].shape[rank_dim] // tp_size, dim=rank_dim + ) + for i in range(len(part_sizes)) + ] + + # Interleave the parts across TP ranks and flatten the list of lists into a single list. + # e.g. [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]] + # -> [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN] where N is TP size, for attn_qkv. + return list(itertools.chain.from_iterable(zip(*weight_parts_tp_weights))) + + def prepare_fused_lora_modules_for_tp( + lora_module: str, t_out: torch.Tensor, rank_dim: int + ) -> torch.Tensor: + """Reorders fused LoRA modules weights for TP. This is required since HF stores the parts weights + sequentially, whereas with TP>1 we need them to be interleaved so they would be sharded correctly. + + See interleave_fused_lora_weights_for_tp for more details. + """ # noqa: D205 + tp_size = self._mapping.tp_size + if tp_size == 1: + return t_out + part_sizes = [] + if lora_module == "mlp_gate_up": + assert t_out.shape[rank_dim] % 2 == 0 + half_size = t_out.shape[rank_dim] // 2 + part_sizes = [half_size, half_size] + elif lora_module == "attn_qkv": + # The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already + # divided by tp_size in tensorrt_llm/_torch/model_config.py::ModelConfig.get_bindings_model_config + q_size = self._model_config.head_size * self._model_config.num_heads * tp_size + kv_size = self._model_config.head_size * self._model_config.num_kv_heads * tp_size + part_sizes = [q_size, kv_size, kv_size] + + if part_sizes: + interleaved_parts = interleave_fused_lora_weights_for_tp( + t_out, rank_dim, tp_size, part_sizes + ) + # Concatenate them all after interleaving, as the CPP implementation expects the full non-split weights. 
+ t_out = torch.cat(interleaved_parts, dim=rank_dim) + return t_out + def load_from_model_dir(uid, model_dir, hf_config): if uid not in self._cpp_lora_weights: self._cpp_lora_weights[uid] = [] # Will be converted to tensor later @@ -1060,36 +1114,9 @@ class LoraManager(object): t_mag = module_weights.get("magnitude", None) is_dora = t_mag is not None - - if lora_module in ["moe_router", "mlp_router"]: - pass - elif "moe" in lora_module and runtime_mapping.has_moe_ep(): - pass - elif lora_module in [ - "attn_dense", - "cross_attn_dense", - "mlp_4h_to_h", - "moe_4h_to_h", - ]: - # split by row - dim = 2 if has_expert_indices else 1 - assert t_in.shape[dim] % tp_size == 0 - t_in = torch.split(t_in, t_in.shape[dim] // tp_size, dim=dim)[ - tp_rank - ].contiguous() - else: - # split by column - dim = 1 if has_expert_indices else 0 - assert t_out.shape[dim] % tp_size == 0 - t_out = torch.split(t_out, t_out.shape[dim] // tp_size, dim=dim)[ - tp_rank - ].contiguous() - if dim == 0 and is_dora and t_mag is not None: - t_mag = torch.split(t_mag, t_mag.shape[0] // tp_size, dim=0)[ - tp_rank - ].contiguous() - rank_dim = 1 if has_expert_indices else 0 + t_out = prepare_fused_lora_modules_for_tp(lora_module, t_out, rank_dim) + effective_rank = t_in.shape[rank_dim] t_in = t_in.cuda().contiguous() diff --git a/tensorrt_llm/runtime/enc_dec_model_runner.py b/tensorrt_llm/runtime/enc_dec_model_runner.py index f2f482a250..57ed27ae09 100644 --- a/tensorrt_llm/runtime/enc_dec_model_runner.py +++ b/tensorrt_llm/runtime/enc_dec_model_runner.py @@ -174,12 +174,14 @@ class EncDecModelRunner: # encoder lora manager setup if self.encoder_model_config.lora_plugin: - self.encoder_lora_manager = LoraManager() + self.encoder_lora_manager = LoraManager( + mapping=self.encoder_runtime_mapping, + model_config=self.encoder_model_config, + ) # TODO: this is only for bart self.encoder_lora_manager.load_from_hf( model_dirs=lora_dir, model_config=self.encoder_model_config, - runtime_mapping=self.encoder_runtime_mapping, component='encoder', ) else: @@ -197,12 +199,14 @@ class EncDecModelRunner: # decoder lora manager setup if self.decoder_model_config.lora_plugin: - self.decoder_lora_manager = LoraManager() + self.decoder_lora_manager = LoraManager( + mapping=self.decoder_runtime_mapping, + model_config=self.decoder_model_config, + ) # TODO: this is only for bart self.decoder_lora_manager.load_from_hf( model_dirs=lora_dir, model_config=self.decoder_model_config, - runtime_mapping=self.decoder_runtime_mapping, component='decoder', ) else: diff --git a/tensorrt_llm/runtime/generation.py b/tensorrt_llm/runtime/generation.py index bf6e228f76..36cdbf0aca 100755 --- a/tensorrt_llm/runtime/generation.py +++ b/tensorrt_llm/runtime/generation.py @@ -40,7 +40,8 @@ from tensorrt_llm.runtime.memory_pools.pools_kv_cache_manager import \ PoolsKVCacheManager from tensorrt_llm.runtime.redrafter_utils import * -from .._utils import (pad_vocab_size, str_dtype_to_torch, torch_to_numpy, +from .._utils import (binding_layer_type_to_str, binding_to_str_dtype, + pad_vocab_size, str_dtype_to_torch, torch_to_numpy, trt_dtype_to_torch) from ..bindings import KVCacheType, ipc_nvls_allocate, ipc_nvls_free from ..layers import LanguageAdapterConfig @@ -653,6 +654,41 @@ class ModelConfig: # language adapter language_adapter_config: Optional[LanguageAdapterConfig] = None + @classmethod + def from_model_config_cpp(cls, model_config_cpp) -> 'ModelConfig': + """Create a partially initialized ModelConfig instance from a given ModelConfig CPP binding instance. 
+ + Note that each of these classes have fields that don't exist in the other, so the created ModelConfigPython + won't have all of its fields initialized. + """ + return cls( + max_batch_size=model_config_cpp.max_batch_size, + max_beam_width=model_config_cpp.max_beam_width, + vocab_size=model_config_cpp.vocab_size, + num_layers=model_config_cpp.num_layers(), + num_heads=model_config_cpp.num_heads, + num_kv_heads=model_config_cpp.num_kv_heads(0), + hidden_size=model_config_cpp.hidden_size, + remove_input_padding=model_config_cpp.use_packed_input, + kv_cache_type=model_config_cpp.kv_cache_type, + cross_attention=model_config_cpp.use_cross_attention, + head_size=model_config_cpp.head_size, + max_prompt_embedding_table_size=model_config_cpp. + max_prompt_embedding_table_size, + quant_mode=QuantMode(model_config_cpp.quant_mode.value), + gather_context_logits=model_config_cpp.compute_context_logits, + gather_generation_logits=model_config_cpp.compute_generation_logits, + gpt_attention_plugin=model_config_cpp.use_gpt_attention_plugin, + dtype=binding_to_str_dtype(model_config_cpp.data_type), + num_kv_heads_per_layer=model_config_cpp.num_kv_heads_per_layer, + tokens_per_block=model_config_cpp.tokens_per_block, + lora_plugin=model_config_cpp.use_lora_plugin, + layer_types=[ + binding_layer_type_to_str(lt) + for lt in model_config_cpp.layer_types + ], + ) + @dataclass class SamplingConfig: diff --git a/tensorrt_llm/runtime/model_runner.py b/tensorrt_llm/runtime/model_runner.py index ee35da3ef0..94965e66d2 100644 --- a/tensorrt_llm/runtime/model_runner.py +++ b/tensorrt_llm/runtime/model_runner.py @@ -611,11 +611,11 @@ class ModelRunner(ModelRunnerMixin): session.runtime._set_weight_streaming(gpu_weights_percent) if session.use_lora_plugin: - lora_manager = LoraManager() + lora_manager = LoraManager(mapping=runtime_mapping, + model_config=model_config) if lora_dir is not None: lora_manager.load_from_ckpt(lora_dir, model_config=model_config, - runtime_mapping=runtime_mapping, ckpt_source=lora_ckpt_source) else: lora_manager = None @@ -720,11 +720,11 @@ class ModelRunner(ModelRunnerMixin): debug_mode=debug_mode, stream=stream) if session.use_lora_plugin: - lora_manager = LoraManager() + lora_manager = LoraManager(mapping=runtime_mapping, + model_config=model_config) if lora_dir is not None: lora_manager.load_from_ckpt(lora_dir, model_config=model_config, - runtime_mapping=runtime_mapping, ckpt_source=lora_ckpt_source) else: lora_manager = None diff --git a/tensorrt_llm/runtime/model_runner_cpp.py b/tensorrt_llm/runtime/model_runner_cpp.py index b701f245f6..9689526807 100644 --- a/tensorrt_llm/runtime/model_runner_cpp.py +++ b/tensorrt_llm/runtime/model_runner_cpp.py @@ -32,8 +32,9 @@ from ..builder import EngineConfig from ..layers import MropeParams from ..logger import logger from ..mapping import Mapping -from .generation import (LogitsProcessor, LoraManager, SamplingConfig, - StoppingCriteria) +from .generation import LogitsProcessor, LoraManager +from .generation import ModelConfig as ModelConfigPython +from .generation import SamplingConfig, StoppingCriteria from .model_runner import ModelRunnerMixin, _engine_config_to_model_config _bindings_dtype_to_torch_dtype_dict = { @@ -277,7 +278,11 @@ class ModelRunnerCpp(ModelRunnerMixin): engine_config = EngineConfig.from_json_file(f"{engine_dir}/config.json") if model_config.use_lora_plugin and rank == 0: - lora_manager = LoraManager() + mapping = _world_config_to_mapping(world_config) + lora_manager = LoraManager( + mapping=mapping, + 
model_config=ModelConfigPython.from_model_config_cpp( + model_config)) if lora_dir is None: config_lora_dir = engine_config.build_config.lora_config.lora_dir if len(config_lora_dir) > 0: @@ -292,7 +297,6 @@ class ModelRunnerCpp(ModelRunnerMixin): # For Executor, only rank 0 can enqueue requests, and should hold all lora weights lora_manager.load_from_ckpt(lora_dir, model_config=runtime_model_config, - runtime_mapping=None, ckpt_source=lora_ckpt_source) else: raise RuntimeError( diff --git a/tests/unittest/llmapi/lora_test_utils.py b/tests/unittest/llmapi/lora_test_utils.py index 58673aa069..a123df495b 100644 --- a/tests/unittest/llmapi/lora_test_utils.py +++ b/tests/unittest/llmapi/lora_test_utils.py @@ -2,7 +2,7 @@ import json import tarfile import tempfile from pathlib import Path -from typing import OrderedDict, Type +from typing import List, OrderedDict, Type import torch from utils.llm_data import llm_models_root @@ -11,10 +11,62 @@ from utils.util import duplicate_list_to_length, flatten_list, similar from tensorrt_llm import SamplingParams from tensorrt_llm.executor.request import LoRARequest from tensorrt_llm.llmapi.llm import BaseLLM +from tensorrt_llm.lora_helper import LoraConfig + +_RU_LORA_ADAPTER_PROMPTS = [ + "Назови главную площадь в центре Москвы.", + "Напиши полное предложение, описывающее, что в музее не хватает женских скульптур. Используй фразу \"не хватает\".", + "Что означает выражение \"водить за нос\"? Объясни в двух словах.", +] + + +def _generate_phi3_response_lora_fused_modules(llm_class: Type[BaseLLM], + prompts: List[str], + **extra_llm_kwargs) -> List[str]: + """Generates responses with LoRA requests with the Phi-3-mini-4k-instruct-ru-lora adapter. + The used LoRA adapter has fused attention QKV and fused MLP gate up proj modules. + Returns the generated texts. + """ # noqa: D205 + hf_model_dir = f"{llm_models_root()}/Phi-3/Phi-3-mini-4k-instruct" + hf_lora_dir = f"{llm_models_root()}/lora/phi/Phi-3-mini-4k-instruct-ru-lora" + + lora_req = LoRARequest("ru-lora", 0, hf_lora_dir) + sampling_params = SamplingParams(max_tokens=20) + + lora_config = LoraConfig(lora_dir=[hf_lora_dir], + max_lora_rank=16, + max_loras=2, + max_cpu_loras=2) + + lora_requests = [lora_req] * len(prompts) + with llm_class(hf_model_dir, lora_config=lora_config, + **extra_llm_kwargs) as llm: + outputs = llm.generate(prompts, + sampling_params, + lora_request=lora_requests) + + return [output.outputs[0].text for output in outputs] + + +def check_phi3_lora_fused_modules_output_tp2_identical_to_tp1( + llm_class: Type[BaseLLM], **extra_llm_kwargs) -> None: + """Tests the output with LoRA requests with the Phi-3-mini-4k-instruct-ru-lora adapter with TP=2 is identical to + the output with TP=1. + That LoRA adapter has fused attention QKV and fused MLP gate up proj modules. + """ # noqa: D205 + extra_llm_kwargs["tensor_parallel_size"] = 1 + outputs_tp1 = _generate_phi3_response_lora_fused_modules( + llm_class, _RU_LORA_ADAPTER_PROMPTS, **extra_llm_kwargs) + + extra_llm_kwargs["tensor_parallel_size"] = 2 + outputs_tp2 = _generate_phi3_response_lora_fused_modules( + llm_class, _RU_LORA_ADAPTER_PROMPTS, **extra_llm_kwargs) + + assert outputs_tp1 == outputs_tp2 def check_llama_7b_multi_unique_lora_adapters_from_request( - lora_adapter_count_per_call: list[int], repeat_calls: int, + lora_adapter_count_per_call: List[int], repeat_calls: int, repeats_per_call: int, llm_class: Type[BaseLLM], **llm_kwargs): """Calls llm.generate s.t. 
for each C in lora_adapter_count_per_call, llm.generate is called with C requests repeated 'repeats_per_call' times, where each request is configured with a unique LoRA adapter ID. diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py index b145122d17..f4fa75e7da 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py @@ -5,7 +5,7 @@ from .test_llm import tinyllama_logits_processor_test_harness, llama_model_path from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.lora_helper import LoraConfig -from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness +from .lora_test_utils import check_llama_7b_multi_lora_from_request_test_harness, check_phi3_lora_fused_modules_output_tp2_identical_to_tp1 from .test_llm_pytorch import llama_7b_lora_from_dir_test_harness from .test_llm import _test_llm_capture_request_error from utils.util import skip_ray @@ -62,6 +62,15 @@ def test_llama_7b_multi_lora_tp2(): cuda_graph_config=None) +@pytest.mark.gpu2 +def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None: + check_phi3_lora_fused_modules_output_tp2_identical_to_tp1( + LLM, + # Disable CUDA graph + # TODO: remove this once we have a proper fix for CUDA graph in LoRA + cuda_graph_config=None) + + @pytest.mark.skip(reason="https://nvbugs/5560921") @skip_ray @pytest.mark.gpu2
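Illustrative sketch (not part of the applied diff): the snippet below reproduces, in isolation, the reordering that interleave_fused_lora_weights_for_tp applies to a fused attn_qkv LoRA-B ("out") weight. The helper name interleave_for_tp and the toy sizes are hypothetical; per the patch comment, the interleaved parts are concatenated back to the full shape and the actual per-rank split is done downstream by the C++ implementation, which expects the full non-split weights.

import torch

def interleave_for_tp(weight, part_sizes, tp_size, dim=0):
    # Split the fused weight into its parts, e.g. [Wq, Wk, Wv] for attn_qkv.
    parts = torch.split(weight, part_sizes, dim=dim)
    # Split each part into per-rank chunks: [[Wq_0, Wq_1, ...], [Wk_0, Wk_1, ...], ...].
    per_rank = [torch.chunk(p, tp_size, dim=dim) for p in parts]
    # Interleave across ranks ([Wq_0, Wk_0, Wv_0, Wq_1, Wk_1, Wv_1, ...]) and re-fuse.
    return torch.cat([t for rank_parts in zip(*per_rank) for t in rank_parts], dim=dim)

# Toy GQA-style sizes: q rows = head_size * num_heads, kv rows = head_size * num_kv_heads.
q_size, kv_size, lora_rank, tp_size = 16, 8, 8, 2
w_q = torch.randn(q_size, lora_rank)
w_k = torch.randn(kv_size, lora_rank)
w_v = torch.randn(kv_size, lora_rank)
fused = torch.cat([w_q, w_k, w_v], dim=0)  # HF layout: parts stored sequentially

reordered = interleave_for_tp(fused, [q_size, kv_size, kv_size], tp_size)

# After reordering, a plain contiguous chunk along the output dimension gives every TP rank
# its own [q | k | v] slice, which is what the downstream row-wise split expects.
rank0 = torch.chunk(reordered, tp_size, dim=0)[0]
expected_rank0 = torch.cat([w_q[: q_size // tp_size],
                            w_k[: kv_size // tp_size],
                            w_v[: kv_size // tp_size]], dim=0)
assert torch.equal(rank0, expected_rank0)

Without the interleaving, chunking the sequentially fused weight would hand rank 0 only Wq rows and rank 1 a mix of Wk/Wv rows, which is the TP>1 corruption this patch addresses.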