From 7d94c9561f0c82a23b24dadcbf7b598dbc866a6d Mon Sep 17 00:00:00 2001 From: shaharmor98 <17088876+shaharmor98@users.noreply.github.com> Date: Thu, 8 May 2025 18:45:45 +0300 Subject: [PATCH] feat: support multi lora adapters and TP (#3885) * support multi lora, tp Signed-off-by: Shahar Mor <17088876+shaharmor98@users.noreply.github.com> --- examples/pytorch/quickstart_lora.py | 38 ------- tensorrt_llm/_torch/model_config.py | 10 +- tensorrt_llm/_torch/models/modeling_llama.py | 22 ++-- tensorrt_llm/_torch/models/modeling_utils.py | 30 ++++-- tensorrt_llm/_torch/modules/attention.py | 14 +-- tensorrt_llm/_torch/modules/gated_mlp.py | 50 ++++----- tensorrt_llm/_torch/modules/linear.py | 30 +++++- tensorrt_llm/_torch/modules/mlp.py | 25 ++--- tensorrt_llm/_torch/peft/lora/layer.py | 9 +- tensorrt_llm/_torch/pyexecutor/_util.py | 27 ++++- .../_torch/pyexecutor/resource_manager.py | 11 +- tensorrt_llm/llmapi/llm.py | 19 +++- tensorrt_llm/lora_manager.py | 3 +- .../_torch/modeling/test_modeling_llama.py | 37 +------ .../unittest/_torch/test_resource_manager.py | 4 - tests/unittest/llmapi/test_llm.py | 1 + .../llmapi/test_llm_multi_gpu_pytorch.py | 18 ++++ tests/unittest/llmapi/test_llm_pytorch.py | 101 ++++++++++++++++-- 18 files changed, 274 insertions(+), 175 deletions(-) delete mode 100644 examples/pytorch/quickstart_lora.py diff --git a/examples/pytorch/quickstart_lora.py b/examples/pytorch/quickstart_lora.py deleted file mode 100644 index 7feada77ab..0000000000 --- a/examples/pytorch/quickstart_lora.py +++ /dev/null @@ -1,38 +0,0 @@ -from tensorrt_llm import SamplingParams -from tensorrt_llm._torch import LLM -from tensorrt_llm.executor import LoRARequest -from tensorrt_llm.lora_manager import LoraConfig - - -def main(): - lora_config = LoraConfig(lora_dir=[ - "/home/scratch.trt_llm_data/llm-models/llama-models-v2/chinese-llama-2-lora-13b" - ], - max_lora_rank=64) - llm = LLM( - model= - "/home/scratch.trt_llm_data/llm-models/llama-models-v2/llama-v2-13b-hf", - lora_config=lora_config, - ) - prompts = [ - "今天天气很好,我到公园的时候,", - ] - - sampling_params = SamplingParams(max_tokens=20, add_special_tokens=False) - lora_req_2 = LoRARequest( - "task-0", 0, - "/home/scratch.trt_llm_data/llm-models/llama-models-v2/chinese-llama-2-lora-13b" - ) - lora_request = [lora_req_2] - - outputs = llm.generate(prompts, sampling_params, lora_request=lora_request) - - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}") - - -if __name__ == '__main__': - main() diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index f921b59557..0beb82a357 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -168,22 +168,18 @@ class ModelConfig(Generic[TConfig]): quant_config_dict=layer_quant_config, **kwargs) - def get_bindings_model_config( - self, - tensor_parallelism: int = 1, - context_parallelism: int = 1) -> "ModelConfigCpp": + def get_bindings_model_config(self) -> "ModelConfigCpp": """ This method is used to construct the bindings config for the model. Currently it adheres to gptJsonConfig.cpp::createModelConfig, which assumes that an engine has been created. 
""" # TODO smor- this isn't robust, and currently tested for LlamaConfig only - # TODO smor- currently parallelism is not supported, set default to 1 # TODO smor- currently assuming no rnn layers, no MOE from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp num_heads = self.pretrained_config.num_attention_heads // ( - tensor_parallelism * context_parallelism) + self.mapping.tp_size * self.mapping.cp_size) model_config_cpp = ModelConfigCpp( vocab_size=self.pretrained_config.vocab_size, @@ -195,7 +191,7 @@ class ModelConfig(Generic[TConfig]): data_type=torch_dtype_to_binding( self.pretrained_config.torch_dtype)) - mlp_hidden_size = self.pretrained_config.intermediate_size // tensor_parallelism + mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size if "head_size" in self.pretrained_config: head_size = self.pretrained_config.head_size else: diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index ab8f9b4c35..6976eb2cff 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -13,6 +13,7 @@ from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, DeepseekAllReduce) from tensorrt_llm._torch.pipeline_interface import PipelineInterface from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.models.convert_utils import split_matrix_tp from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt, register_input_processor) @@ -773,13 +774,14 @@ class LlamaModel(DecoderModel): self.padding_idx = config.pad_token_id vocab_size = config.vocab_size - # TODO smor- hack - if hasattr(model_config, - 'lora_config') and model_config.lora_config is not None: + # TODO smor- we load manually only if there is a single lora dir, need to come up with a better solution + if hasattr( + model_config, + 'lora_config') and model_config.lora_config is not None and len( + model_config.lora_config.lora_dir) == 1: from tensorrt_llm.lora_manager import HfLoraLoader lora_loader = HfLoraLoader(model_config.lora_config.lora_dir) weight = lora_loader.embed_tokens - # TODO smor - need to split tp matrix here vocab_size = lora_loader.vocab_size self.embed_tokens = Embedding( @@ -791,9 +793,17 @@ class LlamaModel(DecoderModel): gather_output=True, ) - if hasattr(model_config, - 'lora_config') and model_config.lora_config is not None: + if hasattr( + model_config, + 'lora_config') and model_config.lora_config is not None and len( + model_config.lora_config.lora_dir) == 1: with torch.no_grad(): + if model_config.mapping.tp_size > 1: + weight = split_matrix_tp( + weight, + model_config.mapping.tp_size, + model_config.mapping.tp_rank, + dim=0) # split by vocabulary dimension x = weight.to(self.embed_tokens.dtype) self.embed_tokens.weight.data.copy_(x) diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py index 28a7257d40..45e0b1ec27 100755 --- a/tensorrt_llm/_torch/models/modeling_utils.py +++ b/tensorrt_llm/_torch/models/modeling_utils.py @@ -11,6 +11,9 @@ from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_any_only from tqdm import tqdm +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models.convert_utils import split_matrix_tp + from ...logger import logger from ...mapping import Mapping from ...models.modeling_utils import QuantConfig @@ -240,7 +243,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller): 
input_ids: torch.LongTensor = None, position_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - lora_params: Optional = None, # TODO smor add type hint + lora_params: Optional[dict] = None, **kwargs, ) -> torch.Tensor: if (input_ids is None) ^ (inputs_embeds is not None): @@ -357,9 +360,9 @@ class DecoderModelForCausalLM(nn.Module, # TODO(zhenhuanc): Currently lm_head Linear will not accept QuantConfig # will considering per layer QuantConfig in the future. - # TODO smor- hack - if hasattr(config, - 'lora_config') and config.lora_config is not None: + if hasattr(config, 'lora_config' + ) and config.lora_config is not None and len( + config.lora_config.lora_dir) == 1: from tensorrt_llm.lora_manager import HfLoraLoader lora_loader = HfLoraLoader(config.lora_config.lora_dir) weight = lora_loader.lm_head @@ -374,9 +377,16 @@ class DecoderModelForCausalLM(nn.Module, gather_output=True, ) - if hasattr(config, - 'lora_config') and config.lora_config is not None: + if hasattr(config, 'lora_config' + ) and config.lora_config is not None and len( + config.lora_config.lora_dir) == 1: with torch.no_grad(): + if config.mapping.tp_size > 1: + weight = split_matrix_tp( + weight, + config.mapping.tp_size, + config.mapping.tp_rank, + dim=0) # split by vocabulary dimension x = weight.to(self.lm_head.dtype).cuda() self.lm_head.weight.data.copy_(x) @@ -475,7 +485,7 @@ class DecoderModelForCausalLM(nn.Module, pipeline_interface: Optional[PipelineInterface] = None, return_context_logits: bool = False, spec_metadata: Optional[SpecMetadata] = None, - lora_params: Optional = None, # TODO smor add type hint + lora_params: Optional[dict] = None, **kwargs, ) -> torch.Tensor: if self._supports_pp and self.pp_size > 1: @@ -657,8 +667,10 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM], # Skip loading weights for embedding and lm_head if LoRA is enabled if hasattr(model.model_config, 'lora_config' - ) and model.model_config.lora_config is not None and ( - name == "model.embed_tokens" or name == "lm_head"): + ) and model.model_config.lora_config is not None and len( + model.model_config.lora_config.lora_dir) == 1 and ( + name == "model.embed_tokens" + or name == "lm_head"): continue # Skip if parameter belongs to a missing layer diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index f16efe8caa..b5075e81df 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -88,6 +88,9 @@ class Attention(nn.Module): quant_config=config.get_quant_config(), skip_create_weights_in_init=config.skip_create_weights_in_init, ) + self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE], + [self.hidden_size]) + self.o_proj = Linear( tp_size * self.q_size, self.hidden_size, @@ -97,6 +100,7 @@ class Attention(nn.Module): tensor_parallel_mode=TensorParallelMode.ROW, quant_config=config.get_quant_config(), skip_create_weights_in_init=config.skip_create_weights_in_init, + lora=self.o_lora, ) self.quant_config = config.get_quant_config() self.attn_backend = config.attn_backend @@ -229,13 +233,9 @@ class Attention(nn.Module): mrope_config=mrope_config) hidden_states = attn_output attn_output = self.o_proj(attn_output, - all_reduce_params=all_reduce_params) - if bool(lora_params): - attn_lora_output = self.o_lora(hidden_states, lora_params, - self.layer_idx) - if attn_lora_output is not None: - attn_output = attn_output + attn_lora_output - + all_reduce_params=all_reduce_params, + 
lora_params=lora_params, + layer_idx=self.layer_idx) return attn_output diff --git a/tensorrt_llm/_torch/modules/gated_mlp.py b/tensorrt_llm/_torch/modules/gated_mlp.py index d1775c4502..7fb134d804 100644 --- a/tensorrt_llm/_torch/modules/gated_mlp.py +++ b/tensorrt_llm/_torch/modules/gated_mlp.py @@ -76,6 +76,9 @@ class GatedMLP(nn.Module): reduce_output=False, skip_create_weights_in_init=config.skip_create_weights_in_init, ) + self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], + [self.hidden_size]) + self.down_proj = Linear( self.intermediate_size, self.hidden_size, @@ -86,18 +89,20 @@ class GatedMLP(nn.Module): quant_config=config.get_quant_config(), reduce_output=reduce_output, skip_create_weights_in_init=config.skip_create_weights_in_init, + lora=self.down_lora, ) # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used, # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora # handles them as a single fused operation. self.splitted_gate_up_lora = LoraLayer( - [LoraModuleType.MLP_H_TO_4H, LoraModuleType.MLP_GATE], - [self.intermediate_size, self.intermediate_size]) - self.fused_gate_up_lora = LoraLayer([LoraModuleType.MLP_GATE_UP], - [2 * self.intermediate_size]) - self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], - [self.hidden_size]) + [LoraModuleType.MLP_H_TO_4H, LoraModuleType.MLP_GATE], [ + self.intermediate_size // mapping.tp_size, + self.intermediate_size // mapping.tp_size + ]) + self.fused_gate_up_lora = LoraLayer( + [LoraModuleType.MLP_GATE_UP], + [2 * self.intermediate_size // mapping.tp_size]) def forward( self, @@ -107,33 +112,17 @@ class GatedMLP(nn.Module): lora_params: Optional[dict] = None, **kwargs, ) -> torch.Tensor: - if lora_params is not None: + if bool(lora_params): return self.forward_lora(x, all_rank_num_tokens, final_all_reduce_params, lora_params) if self.activation == F.silu: h1 = self.gate_up_proj(x) - if bool(lora_params): - assert self.layer_idx is not None, "layer_idx is required for lora" - h1_lora = self.splitted_gate_up_lora(x, lora_params, - self.layer_idx) - if h1_lora is not None: - h1 = h1 + h1_lora - - h1_lora = self.fused_gate_up_lora(x, lora_params, - self.layer_idx) - - if h1_lora is not None: - h1 = h1 + h1_lora h2 = swiglu(h1) output = self.down_proj(h2, - all_reduce_params=final_all_reduce_params) - if bool(lora_params): - output_lora = self.down_lora(h2, lora_params, self.layer_idx) - if output_lora is not None: - output = output + output_lora - + all_reduce_params=final_all_reduce_params, + layer_idx=self.layer_idx) return output else: raise NotImplementedError( @@ -154,19 +143,18 @@ class GatedMLP(nn.Module): h1 = self.gate_up_proj(x) h1_lora = self.splitted_gate_up_lora(x, lora_params, self.layer_idx) + if h1_lora is not None: h1 = h1 + h1_lora h1_lora = self.fused_gate_up_lora(x, lora_params, self.layer_idx) - if h1_lora is not None: h1 = h1 + h1_lora h2 = swiglu(h1) - output = self.down_proj(h2, all_reduce_params=final_all_reduce_params) - - output_lora = self.down_lora(h2, lora_params, self.layer_idx) - if output_lora is not None: - output = output + output_lora + output = self.down_proj(h2, + all_reduce_params=final_all_reduce_params, + lora_params=lora_params, + layer_idx=self.layer_idx) return output diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index e195367120..482b3085d0 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ 
b/tensorrt_llm/_torch/modules/linear.py @@ -9,6 +9,7 @@ from torch import nn from torch.nn.parameter import Parameter import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils +from tensorrt_llm._torch.peft.lora.layer import LoraLayer from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams from tensorrt_llm.mapping import Mapping @@ -157,6 +158,7 @@ class Linear(nn.Module): reduce_output: bool = True, # ROW parallel only skip_create_weights_in_init: bool = False, use_custom_cublas_mm: bool = False, + lora: Optional[LoraLayer] = None, ): from ..distributed import AllReduce @@ -197,6 +199,7 @@ class Linear(nn.Module): self._weights_created = False self.reduce_output = reduce_output self.use_custom_cublas_mm = use_custom_cublas_mm + self.lora = lora if not skip_create_weights_in_init: self.create_weights() @@ -310,7 +313,12 @@ class Linear(nn.Module): self.register_parameter("bias", None) self._weights_created = True - def apply_linear(self, input, weight, bias): + def apply_linear(self, + input, + weight, + bias, + lora_params: Optional[dict] | None = None, + layer_idx: Optional[int] | None = None): if self.has_any_quant: qc = self.quant_config if self.has_fp8_qdq: @@ -368,6 +376,12 @@ class Linear(nn.Module): out_dtype=None) else: output = F.linear(input, self.weight, bias) + + if self.lora is not None and bool(lora_params): + lora_result = self.lora(input, lora_params, layer_idx) + if lora_result is not None: + output = output + lora_result + return output def _maybe_fuse_bias_into_allreduce( @@ -392,6 +406,8 @@ class Linear(nn.Module): input: Union[torch.Tensor, Fp4QuantizedTensor], *, all_reduce_params: Optional[AllReduceParams] = None, + lora_params: Optional[dict] = None, + layer_idx: Optional[int] = None, ) -> torch.Tensor: from ..distributed import allgather @@ -401,19 +417,23 @@ class Linear(nn.Module): fuse_bias = self._maybe_fuse_bias_into_allreduce( bias, all_reduce_params) bias = None if fuse_bias else bias - output = self.apply_linear(input, self.weight, bias) + output = self.apply_linear(input, self.weight, bias, + lora_params, layer_idx) output = self.all_reduce( output, all_reduce_params=all_reduce_params, ) else: - output = self.apply_linear(input, self.weight, bias) + output = self.apply_linear(input, self.weight, bias, + lora_params, layer_idx) elif self.tp_mode == TensorParallelMode.COLUMN: - output = self.apply_linear(input, self.weight, self.bias) + output = self.apply_linear(input, self.weight, self.bias, + lora_params, layer_idx) if self.gather_output: output = allgather(output, self.mapping) else: - output = self.apply_linear(input, self.weight, self.bias) + output = self.apply_linear(input, self.weight, self.bias, + lora_params, layer_idx) return output diff --git a/tensorrt_llm/_torch/modules/mlp.py b/tensorrt_llm/_torch/modules/mlp.py index 842a942ca2..8d026e1fa2 100644 --- a/tensorrt_llm/_torch/modules/mlp.py +++ b/tensorrt_llm/_torch/modules/mlp.py @@ -28,6 +28,10 @@ class MLP(nn.Module): self.activation = activation config = config or ModelConfig() + self.up_lora = LoraLayer( + [LoraModuleType.MLP_H_TO_4H], + [self.intermediate_size // config.mapping.tp_size]) + self.up_proj = Linear( self.hidden_size, self.intermediate_size, @@ -38,8 +42,11 @@ class MLP(nn.Module): weights_loading_config=WeightsLoadingConfig( weight_mode=WeightMode.VANILLA), quant_config=config.get_quant_config(), - skip_create_weights_in_init=config.skip_create_weights_in_init) + skip_create_weights_in_init=config.skip_create_weights_in_init, + lora=self.up_lora) + 
self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], + [self.hidden_size]) self.down_proj = Linear( self.intermediate_size, self.hidden_size, @@ -48,12 +55,8 @@ class MLP(nn.Module): mapping=config.mapping, tensor_parallel_mode=TensorParallelMode.ROW, quant_config=config.get_quant_config(), - skip_create_weights_in_init=config.skip_create_weights_in_init) - - self.up_lora = LoraLayer([LoraModuleType.MLP_H_TO_4H], - [self.intermediate_size]) - self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H], - [self.hidden_size]) + skip_create_weights_in_init=config.skip_create_weights_in_init, + lora=self.down_lora) def forward( self, @@ -84,10 +87,8 @@ class MLP(nn.Module): x_up = x_up + x_up_lora x_act = self.activation(x_up) - x_down = self.down_proj(x_act) - - x_down_lora = self.down_lora(x_act, lora_params, self.layer_idx) - if x_down_lora is not None: - x_down = x_down + x_down_lora + x_down = self.down_proj(x_act, + lora_params=lora_params, + layer_idx=self.layer_idx) return x_down diff --git a/tensorrt_llm/_torch/peft/lora/layer.py b/tensorrt_llm/_torch/peft/lora/layer.py index e9ce03dae4..fb98461417 100644 --- a/tensorrt_llm/_torch/peft/lora/layer.py +++ b/tensorrt_llm/_torch/peft/lora/layer.py @@ -92,8 +92,12 @@ class LoraLayer(torch.nn.Module): self.output_hidden_sizes = output_hidden_sizes assert len(lora_module_types) == len(output_hidden_sizes) - def forward(self, x, lora_params: Dict, - layer_idx: int) -> Optional[torch.Tensor]: + def forward( + self, + x, + lora_params: Dict, + layer_idx: int, + ) -> Optional[torch.Tensor]: if bool(lora_params): lora_ranks = [] @@ -147,7 +151,6 @@ class LoraLayer(torch.nn.Module): ], dtype=x.dtype, device=x.device)) - # TODO smor should be by: dim=q_lora.rank() - 1 lora_output = torch.cat(lora_output, dim=-1) return lora_output diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 4edbb2e94f..865975b4c4 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -10,7 +10,9 @@ import tensorrt_llm.bindings.executor as trtllm from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig from tensorrt_llm.logger import logger -from tensorrt_llm.lora_manager import LoraConfig, load_torch_hf_lora +from tensorrt_llm.lora_manager import (LoraConfig, + get_default_trtllm_modules_to_hf_modules, + load_torch_hf_lora) from tensorrt_llm.mapping import Mapping from ..speculative import get_num_spec_layers, get_spec_decoder @@ -380,7 +382,16 @@ def create_py_executor_instance(dist, if lora_config is not None: from tensorrt_llm.bindings import LoraModule - load_torch_hf_lora(lora_config) + + if len(lora_config.lora_dir) == 1: + load_torch_hf_lora(lora_config) + else: + assert len(lora_config.lora_target_modules + ) >= 1, "Expecting at least one lora target module" + if not bool(lora_config.trtllm_modules_to_hf_modules): + lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules( + ) + model_binding_config = model_engine.model.model_config.get_bindings_model_config( ) lora_modules = LoraModule.create_lora_modules( @@ -406,9 +417,19 @@ def create_py_executor_instance(dist, max_cpu_loras, ) + from tensorrt_llm.bindings import WorldConfig + world_config = WorldConfig( + tensor_parallelism=mapping.tp_size, + pipeline_parallelism=mapping.pp_size, + context_parallelism=mapping.cp_size, + rank=dist.mapping.rank, + gpus_per_node=dist.mapping.gpus_per_node, + ) 
peft_cache_manager = PeftCacheManager( peft_cache_config=executor_config.peft_cache_config, - model_config=model_binding_config) + model_config=model_binding_config, + world_config=world_config, + ) resources["peft_cache_manager"] = peft_cache_manager model_engine.set_lora_model_config( lora_config.lora_target_modules, diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index d14d2b1548..893481d9ac 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -30,6 +30,7 @@ KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEve RequestList = list[LlmRequest] PeftCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.PeftCacheManager PeftCacheConfig = tensorrt_llm.bindings.executor.PeftCacheConfig +WorldConfig = tensorrt_llm.bindings.WorldConfig def compute_page_count(token_count: int, tokens_per_page: int) -> int: @@ -715,8 +716,10 @@ class ResourceManager: class PeftCacheManager(BaseResourceManager): - def __init__(self, peft_cache_config: PeftCacheConfig, - model_config: ModelConfig): + def __init__(self, + peft_cache_config: PeftCacheConfig, + model_config: ModelConfig, + world_config: WorldConfig | None = None): import tensorrt_llm.bindings as _tb peft_cache_manager_config = _tb.PeftCacheManagerConfig( @@ -735,8 +738,8 @@ class PeftCacheManager(BaseResourceManager): lora_prefetch_dir=peft_cache_config.lora_prefetch_dir, ) - # TODO smor- currently set manually, change that - world_config = _tb.WorldConfig() + if world_config is None: + world_config = _tb.WorldConfig() BufferManager = tensorrt_llm.bindings.internal.runtime.BufferManager buffer_manager = BufferManager(torch.cuda.current_stream().cuda_stream, diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index bfa443c1a0..c94456af17 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -634,7 +634,24 @@ class LLM: if hasattr( self.args, "backend" ) and self.args.backend == "pytorch" and self.args.lora_config is not None: - tokenizer_path = self.args.lora_config.lora_dir[0] + num_lora_dirs = len(self.args.lora_config.lora_dir) + if num_lora_dirs == 1: + tokenizer_path = self.args.lora_config.lora_dir[0] + try: + tokenizer = ModelLoader.load_hf_tokenizer( + tokenizer_path, + trust_remote_code=self.args.trust_remote_code, + use_fast=self.args.tokenizer_mode != 'slow') + return tokenizer + except Exception: + tokenizer_path = self.args.model + elif num_lora_dirs > 1: + # TODO smor- currently not supported, need to determine which tokenizer to use, if possible + raise ValueError( + f"Expecting only a single lora dir, but got {num_lora_dirs}" + ) + else: + tokenizer_path = self.args.model else: tokenizer_path = self.args.model return ModelLoader.load_hf_tokenizer( diff --git a/tensorrt_llm/lora_manager.py b/tensorrt_llm/lora_manager.py index 789edeb886..273f96c0e6 100644 --- a/tensorrt_llm/lora_manager.py +++ b/tensorrt_llm/lora_manager.py @@ -77,7 +77,7 @@ def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None): hf_module = m.group(3) + "." 
+ module_name if hf_module not in hf_modules: hf_module = module_name - assert hf_module in hf_modules, f"hf_module {hf_module} is not in supported llist {hf_modules}" + assert hf_module in hf_modules, f"hf_module {hf_module} is not in supported list {hf_modules}" is_lora_a_or_b = m.group(8) is not None if is_lora_a_or_b: @@ -276,6 +276,7 @@ def load_torch_hf_lora(lora_config: LoraConfig): lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules( ) + assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir" lora_loader = HfLoraLoader(lora_config.lora_dir) if len(lora_config.lora_target_modules) == 0: diff --git a/tests/unittest/_torch/modeling/test_modeling_llama.py b/tests/unittest/_torch/modeling/test_modeling_llama.py index def70893f4..c3fed8be6f 100644 --- a/tests/unittest/_torch/modeling/test_modeling_llama.py +++ b/tests/unittest/_torch/modeling/test_modeling_llama.py @@ -7,12 +7,10 @@ import torch from parameterized import parameterized from transformers import LlamaConfig from transformers import LlamaForCausalLM as HFLlamaForCausalLM -from utils.llm_data import llm_models_root -from utils.util import getSMVersion, similar, skip_gpu_memory_less_than +from utils.util import getSMVersion import tensorrt_llm from tensorrt_llm._torch.attention_backend.utils import get_attention_backend -from tensorrt_llm._torch.llm import LLM from tensorrt_llm._torch.metadata import KVCacheParams from tensorrt_llm._torch.model_config import ModelConfig from tensorrt_llm._torch.models.modeling_llama import LlamaForCausalLM @@ -20,11 +18,8 @@ from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \ DecodingCUDAGraphRunner from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager from tensorrt_llm.bindings.executor import KvCacheConfig -from tensorrt_llm.executor.request import LoRARequest -from tensorrt_llm.lora_manager import LoraConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig -from tensorrt_llm.sampling_params import SamplingParams LLAMA_3_1_8B_CONFIG = { "architectures": ["LlamaForCausalLM"], @@ -368,33 +363,3 @@ class TestLlama(unittest.TestCase): rtol=0.4) kv_cache_manager.shutdown() - - @skip_gpu_memory_less_than(40 * 2**30) # 40GB memory - def test_llama_lora(self) -> None: - lora_config = LoraConfig(lora_dir=[ - f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b" - ], - max_lora_rank=64) - llm = LLM( - model=f"{llm_models_root()}/llama-models-v2/llama-v2-13b-hf", - lora_config=lora_config, - ) - - prompts = [ - "今天天气很好,我到公园的时候,", - ] - references = [ - "发现公园里到处都是人,有的在跑步,有的在打羽毛球,还有的", - ] - sampling_params = SamplingParams(max_tokens=20, - add_special_tokens=False) - lora_req = LoRARequest( - "task-0", 0, - f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b") - lora_request = [lora_req] - - outputs = llm.generate(prompts, - sampling_params, - lora_request=lora_request) - - assert similar(outputs[0].outputs[0].text, references[0]) diff --git a/tests/unittest/_torch/test_resource_manager.py b/tests/unittest/_torch/test_resource_manager.py index 0e4d2043e6..4ef1202a71 100644 --- a/tests/unittest/_torch/test_resource_manager.py +++ b/tests/unittest/_torch/test_resource_manager.py @@ -38,9 +38,6 @@ class TestResourceManager(unittest.TestCase): """ Setup the lora test data resources """ - # TODO smor- this should be ported to a different place, ideally run once - # in a similar way to the cpp tests fixutres. 
- cpp_script_dir = os.path.join(cls.CPP_RESOURCES_DIR, "scripts") generate_lora_data_args_tp1 = [ @@ -258,7 +255,6 @@ class TestResourceManager(unittest.TestCase): Returns: tuple: (weights tensor, config tensor) formatted correctly for the C++ implementation. """ - # TODO smor- change from custom path configuration to relative path lora_weights = np.load(self.TP1_WEIGHTS_PATH).astype(np.float16) lora_weights = np.expand_dims(lora_weights, axis=0) lora_config = np.load(self.TP1_CONFIG_PATH) diff --git a/tests/unittest/llmapi/test_llm.py b/tests/unittest/llmapi/test_llm.py index 7369c72fa9..8083ce1b20 100644 --- a/tests/unittest/llmapi/test_llm.py +++ b/tests/unittest/llmapi/test_llm.py @@ -1329,6 +1329,7 @@ def test_executor_lookahead_decoding_config(): def llama_v2_13b_lora_test_harness(**llm_kwargs): + # Shahar- perhaps disable build config hf_model_dir = get_model_path("llama-models-v2/llama-v2-13b-hf") hf_lora_dir = get_model_path("llama-models-v2/chinese-llama-2-lora-13b") diff --git a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py index 60678fdf64..2fe3e0c5b5 100644 --- a/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py +++ b/tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py @@ -4,8 +4,14 @@ import pytest from .test_llm import (global_kvcache_config, tinyllama_guided_decoding_test_harness, tinyllama_logits_processor_test_harness) +from tensorrt_llm.llmapi import KvCacheConfig +from .test_llm_pytorch import (llama_v2_13b_lora_test_harness, + llama_7b_multi_lora_test_harness) + # isort: on +global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4) + @pytest.mark.gpu4 def test_tinyllama_guided_decoding_tp2pp2(): @@ -31,3 +37,15 @@ def test_tinyllama_logits_processor_2gpu(tp_size: int, pp_size: int): tinyllama_logits_processor_test_harness(backend="pytorch", tensor_parallel_size=tp_size, pipeline_parallel_size=pp_size) + + +@pytest.mark.gpu2 +def test_llama_v2_13b_lora_tp2(): + llama_v2_13b_lora_test_harness(tensor_parallel_size=2, + kv_cache_config=global_kv_cache_config) + + +@pytest.mark.gpu2 +def test_llama_7b_multi_lora_tp2(): + llama_7b_multi_lora_test_harness(tensor_parallel_size=2, + kv_cache_config=global_kv_cache_config) diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index bf3ad605e4..c7dbeca1a0 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -4,14 +4,17 @@ from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer from tensorrt_llm.sampling_params import SamplingParams # isort: off -from .test_llm import (get_model_path, global_kvcache_config, llama_model_path, - llm_get_stats_async_test_harness, - llm_get_stats_test_harness, prompts, - run_llm_abort_request, - run_llm_with_postprocess_parallel_and_result_handler, - tinyllama_guided_decoding_test_harness, - tinyllama_logits_processor_test_harness) -from utils.util import force_ampere +from .test_llm import ( + get_model_path, global_kvcache_config, llama_model_path, + llm_get_stats_async_test_harness, llm_get_stats_test_harness, prompts, + run_llm_abort_request, run_llm_with_postprocess_parallel_and_result_handler, + tinyllama_guided_decoding_test_harness, + tinyllama_logits_processor_test_harness, llama_7b_multi_lora_test_harness) +from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb +from utils.llm_data import llm_models_root +from tensorrt_llm.lora_manager import LoraConfig +from tensorrt_llm.executor.request 
import LoRARequest + # isort: on @@ -89,3 +92,85 @@ def test_llm_with_postprocess_parallel_and_result_handler(streaming): run_llm_with_postprocess_parallel_and_result_handler(streaming, "pytorch", tp_size=1) + + +def llama_v2_13b_lora_test_harness(**llm_kwargs) -> None: + from tensorrt_llm._torch.llm import LLM + + lora_config = LoraConfig(lora_dir=[ + f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b" + ], + max_lora_rank=64) + llm = LLM(model=f"{llm_models_root()}/llama-models-v2/llama-v2-13b-hf", + lora_config=lora_config, + **llm_kwargs) + + prompts = [ + "今天天气很好,我到公园的时候,", + ] + references = [ + "发现公园里到处都是人,有的在跑步,有的在打羽毛球,还有的", + ] + sampling_params = SamplingParams(max_tokens=20, add_special_tokens=False) + lora_req = LoRARequest( + "task-0", 0, + f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b") + lora_request = [lora_req] + + outputs = llm.generate(prompts, sampling_params, lora_request=lora_request) + + assert similar(outputs[0].outputs[0].text, references[0]) + + +def llama_7b_multi_lora_test_harness(**llm_kwargs) -> None: + from tensorrt_llm._torch.llm import LLM + + hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf" + hf_lora_dir1 = f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1" + hf_lora_dir2 = f"{llm_models_root()}/llama-models/Japanese-Alpaca-LoRA-7b-v0" + + # For LoRA checkpoints without finetuned embedding and lm_head, we can either: + # (1) specify lora_target_modules, or + # (2) provide a lora_dir to infer the lora_target_modules. + lora_config = LoraConfig(lora_target_modules=['attn_q', 'attn_k', 'attn_v'], + max_lora_rank=8) + llm = LLM(hf_model_dir, + fast_build=True, + lora_config=lora_config, + **llm_kwargs) + + prompts = [ + "美国的首都在哪里? \n答案:", + "美国的首都在哪里? \n答案:", + "美国的首都在哪里? \n答案:", + "アメリカ合衆国の首都はどこですか? \n答え:", + "アメリカ合衆国の首都はどこですか? \n答え:", + "アメリカ合衆国の首都はどこですか? \n答え:", + ] + references = [ + "沃尔玛\n\n## 新闻\n\n* ", + "美国的首都是华盛顿。\n\n美国的", + "纽约\n\n### カンファレンスの", + "Washington, D.C.\nWashington, D.C. is the capital of the United", + "华盛顿。\n\n英国の首都是什", + "ワシントン\nQ1. アメリカ合衆国", + ] + lora_req1 = LoRARequest("luotuo", 1, hf_lora_dir1) + lora_req2 = LoRARequest("Japanese", 2, hf_lora_dir2) + sampling_params = SamplingParams(max_tokens=20) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=[None, lora_req1, lora_req2, None, lora_req1, lora_req2]) + for output, ref in zip(outputs, references): + assert similar(output.outputs[0].text, ref) + + +@skip_gpu_memory_less_than_40gb +def test_llama_v2_13b_lora(): + llama_v2_13b_lora_test_harness() + + +@skip_gpu_memory_less_than_40gb +def test_llama_7b_multi_lora(): + llama_7b_multi_lora_test_harness()
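
Reviewer note (not part of the patch): since this change deletes examples/pytorch/quickstart_lora.py, the sketch below illustrates the multi-LoRA plus tensor-parallel path this PR enables through the PyTorch-backend LLM API. It is adapted from the deleted quickstart and the new llama_7b_multi_lora_test_harness; the model and adapter paths are placeholders, and the adapter names/IDs are illustrative only.

# Minimal sketch, assuming local HF checkpoints at the placeholder paths below.
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig


def main():
    # With more than one adapter, lora_target_modules must be set explicitly:
    # after this patch, load_torch_hf_lora() only infers them when a single
    # lora_dir is given.
    lora_config = LoraConfig(
        lora_target_modules=['attn_q', 'attn_k', 'attn_v'],
        max_lora_rank=8)

    llm = LLM(
        model="/path/to/llama-7b-hf",   # placeholder base model path
        lora_config=lora_config,
        tensor_parallel_size=2,         # TP with LoRA is what this PR adds
    )

    prompts = [
        "美国的首都在哪里? \n答案:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
    ]
    sampling_params = SamplingParams(max_tokens=20)

    # One LoRARequest per prompt; a None entry would fall back to the base model.
    lora_req1 = LoRARequest("adapter-1", 1, "/path/to/lora-adapter-1")  # placeholder
    lora_req2 = LoRARequest("adapter-2", 2, "/path/to/lora-adapter-2")  # placeholder

    outputs = llm.generate(prompts,
                           sampling_params,
                           lora_request=[lora_req1, lora_req2])
    for output in outputs:
        print(f"Prompt: {output.prompt!r}\n"
              f"Generated text: {output.outputs[0].text!r}")


if __name__ == '__main__':
    main()

The sketch mirrors the design choice visible in the diff: LoRA application is now plumbed through Linear.forward via lora_params/layer_idx rather than applied ad hoc in each module, so per-request adapters and TP sharding are handled in one place.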