Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
feat: support multi lora adapters and TP (#3885)
* Support multiple LoRA adapters and tensor parallelism (TP). Signed-off-by: Shahar Mor <17088876+shaharmor98@users.noreply.github.com>
Parent: 99313af242
Commit: 7d94c9561f
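
As context for the diff below, here is a minimal usage sketch of what this change enables: serving several LoRA adapters at once on a tensor-parallel PyTorch-backend LLM. It is adapted from the llama_7b_multi_lora_test_harness added in this commit; the model and adapter paths are placeholders, and tensor_parallel_size=2 assumes two GPUs are available.

# Sketch only: paths are placeholders, adapted from the test harness added by this commit.
from tensorrt_llm._torch.llm import LLM
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.sampling_params import SamplingParams

# For LoRA checkpoints without finetuned embedding/lm_head, specify the target modules.
lora_config = LoraConfig(lora_target_modules=['attn_q', 'attn_k', 'attn_v'],
                         max_lora_rank=8)
llm = LLM("/path/to/llama-7b-hf",  # placeholder base model path
          lora_config=lora_config,
          tensor_parallel_size=2)  # TP across two GPUs

# One LoRARequest per adapter; None means "use the base model" for that prompt.
lora_req1 = LoRARequest("adapter-a", 1, "/path/to/lora-adapter-a")
lora_req2 = LoRARequest("adapter-b", 2, "/path/to/lora-adapter-b")

prompts = ["Prompt for the base model", "Prompt for adapter A", "Prompt for adapter B"]
outputs = llm.generate(prompts,
                       SamplingParams(max_tokens=20),
                       lora_request=[None, lora_req1, lora_req2])
for output in outputs:
    print(output.outputs[0].text)
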
@@ -1,38 +0,0 @@
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig


def main():
    lora_config = LoraConfig(lora_dir=[
        "/home/scratch.trt_llm_data/llm-models/llama-models-v2/chinese-llama-2-lora-13b"
    ],
                             max_lora_rank=64)
    llm = LLM(
        model=
        "/home/scratch.trt_llm_data/llm-models/llama-models-v2/llama-v2-13b-hf",
        lora_config=lora_config,
    )
    prompts = [
        "今天天气很好,我到公园的时候,",
    ]

    sampling_params = SamplingParams(max_tokens=20, add_special_tokens=False)
    lora_req_2 = LoRARequest(
        "task-0", 0,
        "/home/scratch.trt_llm_data/llm-models/llama-models-v2/chinese-llama-2-lora-13b"
    )
    lora_request = [lora_req_2]

    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")


if __name__ == '__main__':
    main()
@@ -168,22 +168,18 @@ class ModelConfig(Generic[TConfig]):
                   quant_config_dict=layer_quant_config,
                   **kwargs)

    def get_bindings_model_config(
            self,
            tensor_parallelism: int = 1,
            context_parallelism: int = 1) -> "ModelConfigCpp":
    def get_bindings_model_config(self) -> "ModelConfigCpp":
        """
        This method is used to construct the bindings config for the model.
        Currently it adheres to gptJsonConfig.cpp::createModelConfig, which assumes
        that an engine has been created.
        """
        # TODO smor- this isn't robust, and currently tested for LlamaConfig only
        # TODO smor- currently parallelism is not supported, set default to 1
        # TODO smor- currently assuming no rnn layers, no MOE
        from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp

        num_heads = self.pretrained_config.num_attention_heads // (
            tensor_parallelism * context_parallelism)
            self.mapping.tp_size * self.mapping.cp_size)

        model_config_cpp = ModelConfigCpp(
            vocab_size=self.pretrained_config.vocab_size,
@@ -195,7 +191,7 @@ class ModelConfig(Generic[TConfig]):
            data_type=torch_dtype_to_binding(
                self.pretrained_config.torch_dtype))

        mlp_hidden_size = self.pretrained_config.intermediate_size // tensor_parallelism
        mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
        if "head_size" in self.pretrained_config:
            head_size = self.pretrained_config.head_size
        else:
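
Note on the change above: the bindings config now derives per-rank sizes from the model's own parallel mapping instead of explicit tensor_parallelism/context_parallelism arguments. A small illustrative sketch of that arithmetic, with hypothetical numbers (not taken from the commit):

# Hypothetical values, for illustration of the per-rank sizing used above.
num_attention_heads = 40      # e.g. a Llama-2-13B-like config
intermediate_size = 13824
tp_size, cp_size = 2, 1       # assumed mapping: TP=2, CP=1

num_heads_per_rank = num_attention_heads // (tp_size * cp_size)  # 20
mlp_hidden_size_per_rank = intermediate_size // tp_size          # 6912
print(num_heads_per_rank, mlp_hidden_size_per_rank)
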
@@ -13,6 +13,7 @@ from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
                                             AllReduceParams, DeepseekAllReduce)
from tensorrt_llm._torch.pipeline_interface import PipelineInterface
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.models.convert_utils import split_matrix_tp

from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
                       register_input_processor)
@@ -773,13 +774,14 @@ class LlamaModel(DecoderModel):
        self.padding_idx = config.pad_token_id

        vocab_size = config.vocab_size
        # TODO smor- hack
        if hasattr(model_config,
                   'lora_config') and model_config.lora_config is not None:
        # TODO smor- we load manually only if there is a single lora dir, need to come up with a better solution
        if hasattr(
                model_config,
                'lora_config') and model_config.lora_config is not None and len(
                    model_config.lora_config.lora_dir) == 1:
            from tensorrt_llm.lora_manager import HfLoraLoader
            lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
            weight = lora_loader.embed_tokens
            # TODO smor - need to split tp matrix here
            vocab_size = lora_loader.vocab_size

        self.embed_tokens = Embedding(
@@ -791,9 +793,17 @@ class LlamaModel(DecoderModel):
            gather_output=True,
        )

        if hasattr(model_config,
                   'lora_config') and model_config.lora_config is not None:
        if hasattr(
                model_config,
                'lora_config') and model_config.lora_config is not None and len(
                    model_config.lora_config.lora_dir) == 1:
            with torch.no_grad():
                if model_config.mapping.tp_size > 1:
                    weight = split_matrix_tp(
                        weight,
                        model_config.mapping.tp_size,
                        model_config.mapping.tp_rank,
                        dim=0) # split by vocabulary dimension
                x = weight.to(self.embed_tokens.dtype)
                self.embed_tokens.weight.data.copy_(x)
@@ -11,6 +11,9 @@ from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_any_only
from tqdm import tqdm

from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split_matrix_tp

from ...logger import logger
from ...mapping import Mapping
from ...models.modeling_utils import QuantConfig
@@ -240,7 +243,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        lora_params: Optional = None, # TODO smor add type hint
        lora_params: Optional[dict] = None,
        **kwargs,
    ) -> torch.Tensor:
        if (input_ids is None) ^ (inputs_embeds is not None):
@@ -357,9 +360,9 @@ class DecoderModelForCausalLM(nn.Module,
        # TODO(zhenhuanc): Currently lm_head Linear will not accept QuantConfig
        # will considering per layer QuantConfig in the future.

        # TODO smor- hack
        if hasattr(config,
                   'lora_config') and config.lora_config is not None:
        if hasattr(config, 'lora_config'
                   ) and config.lora_config is not None and len(
                       config.lora_config.lora_dir) == 1:
            from tensorrt_llm.lora_manager import HfLoraLoader
            lora_loader = HfLoraLoader(config.lora_config.lora_dir)
            weight = lora_loader.lm_head
@@ -374,9 +377,16 @@ class DecoderModelForCausalLM(nn.Module,
            gather_output=True,
        )

        if hasattr(config,
                   'lora_config') and config.lora_config is not None:
        if hasattr(config, 'lora_config'
                   ) and config.lora_config is not None and len(
                       config.lora_config.lora_dir) == 1:
            with torch.no_grad():
                if config.mapping.tp_size > 1:
                    weight = split_matrix_tp(
                        weight,
                        config.mapping.tp_size,
                        config.mapping.tp_rank,
                        dim=0) # split by vocabulary dimension
                x = weight.to(self.lm_head.dtype).cuda()
                self.lm_head.weight.data.copy_(x)
@@ -475,7 +485,7 @@ class DecoderModelForCausalLM(nn.Module,
        pipeline_interface: Optional[PipelineInterface] = None,
        return_context_logits: bool = False,
        spec_metadata: Optional[SpecMetadata] = None,
        lora_params: Optional = None, # TODO smor add type hint
        lora_params: Optional[dict] = None,
        **kwargs,
    ) -> torch.Tensor:
        if self._supports_pp and self.pp_size > 1:
@@ -657,8 +667,10 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],

        # Skip loading weights for embedding and lm_head if LoRA is enabled
        if hasattr(model.model_config, 'lora_config'
                   ) and model.model_config.lora_config is not None and (
                       name == "model.embed_tokens" or name == "lm_head"):
                   ) and model.model_config.lora_config is not None and len(
                       model.model_config.lora_config.lora_dir) == 1 and (
                           name == "model.embed_tokens"
                           or name == "lm_head"):
            continue

        # Skip if parameter belongs to a missing layer
@@ -88,6 +88,9 @@ class Attention(nn.Module):
            quant_config=config.get_quant_config(),
            skip_create_weights_in_init=config.skip_create_weights_in_init,
        )
        self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                [self.hidden_size])

        self.o_proj = Linear(
            tp_size * self.q_size,
            self.hidden_size,
@@ -97,6 +100,7 @@ class Attention(nn.Module):
            tensor_parallel_mode=TensorParallelMode.ROW,
            quant_config=config.get_quant_config(),
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            lora=self.o_lora,
        )
        self.quant_config = config.get_quant_config()
        self.attn_backend = config.attn_backend
@@ -229,13 +233,9 @@ class Attention(nn.Module):
                                  mrope_config=mrope_config)
        hidden_states = attn_output
        attn_output = self.o_proj(attn_output,
                                  all_reduce_params=all_reduce_params)
        if bool(lora_params):
            attn_lora_output = self.o_lora(hidden_states, lora_params,
                                           self.layer_idx)
            if attn_lora_output is not None:
                attn_output = attn_output + attn_lora_output

                                  all_reduce_params=all_reduce_params,
                                  lora_params=lora_params,
                                  layer_idx=self.layer_idx)
        return attn_output
@@ -76,6 +76,9 @@ class GatedMLP(nn.Module):
            reduce_output=False,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
        )
        self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
                                   [self.hidden_size])

        self.down_proj = Linear(
            self.intermediate_size,
            self.hidden_size,
@@ -86,18 +89,20 @@ class GatedMLP(nn.Module):
            quant_config=config.get_quant_config(),
            reduce_output=reduce_output,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            lora=self.down_lora,
        )

        # These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
        # but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
        # handles them as a single fused operation.
        self.splitted_gate_up_lora = LoraLayer(
            [LoraModuleType.MLP_H_TO_4H, LoraModuleType.MLP_GATE],
            [self.intermediate_size, self.intermediate_size])
        self.fused_gate_up_lora = LoraLayer([LoraModuleType.MLP_GATE_UP],
                                            [2 * self.intermediate_size])
        self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
                                   [self.hidden_size])
            [LoraModuleType.MLP_H_TO_4H, LoraModuleType.MLP_GATE], [
                self.intermediate_size // mapping.tp_size,
                self.intermediate_size // mapping.tp_size
            ])
        self.fused_gate_up_lora = LoraLayer(
            [LoraModuleType.MLP_GATE_UP],
            [2 * self.intermediate_size // mapping.tp_size])

    def forward(
        self,
@@ -107,33 +112,17 @@ class GatedMLP(nn.Module):
        lora_params: Optional[dict] = None,
        **kwargs,
    ) -> torch.Tensor:
        if lora_params is not None:
        if bool(lora_params):
            return self.forward_lora(x, all_rank_num_tokens,
                                     final_all_reduce_params, lora_params)

        if self.activation == F.silu:
            h1 = self.gate_up_proj(x)
            if bool(lora_params):
                assert self.layer_idx is not None, "layer_idx is required for lora"
                h1_lora = self.splitted_gate_up_lora(x, lora_params,
                                                     self.layer_idx)
                if h1_lora is not None:
                    h1 = h1 + h1_lora

                h1_lora = self.fused_gate_up_lora(x, lora_params,
                                                  self.layer_idx)

                if h1_lora is not None:
                    h1 = h1 + h1_lora

            h2 = swiglu(h1)
            output = self.down_proj(h2,
                                    all_reduce_params=final_all_reduce_params)
            if bool(lora_params):
                output_lora = self.down_lora(h2, lora_params, self.layer_idx)
                if output_lora is not None:
                    output = output + output_lora

                                    all_reduce_params=final_all_reduce_params,
                                    layer_idx=self.layer_idx)
            return output
        else:
            raise NotImplementedError(
@@ -154,19 +143,18 @@ class GatedMLP(nn.Module):
        h1 = self.gate_up_proj(x)

        h1_lora = self.splitted_gate_up_lora(x, lora_params, self.layer_idx)

        if h1_lora is not None:
            h1 = h1 + h1_lora

        h1_lora = self.fused_gate_up_lora(x, lora_params, self.layer_idx)

        if h1_lora is not None:
            h1 = h1 + h1_lora

        h2 = swiglu(h1)
        output = self.down_proj(h2, all_reduce_params=final_all_reduce_params)

        output_lora = self.down_lora(h2, lora_params, self.layer_idx)
        if output_lora is not None:
            output = output + output_lora
        output = self.down_proj(h2,
                                all_reduce_params=final_all_reduce_params,
                                lora_params=lora_params,
                                layer_idx=self.layer_idx)

        return output
@@ -9,6 +9,7 @@ from torch import nn
from torch.nn.parameter import Parameter

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.peft.lora.layer import LoraLayer
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams
from tensorrt_llm.mapping import Mapping
@@ -157,6 +158,7 @@ class Linear(nn.Module):
        reduce_output: bool = True, # ROW parallel only
        skip_create_weights_in_init: bool = False,
        use_custom_cublas_mm: bool = False,
        lora: Optional[LoraLayer] = None,
    ):
        from ..distributed import AllReduce
@@ -197,6 +199,7 @@ class Linear(nn.Module):
        self._weights_created = False
        self.reduce_output = reduce_output
        self.use_custom_cublas_mm = use_custom_cublas_mm
        self.lora = lora

        if not skip_create_weights_in_init:
            self.create_weights()
@@ -310,7 +313,12 @@ class Linear(nn.Module):
            self.register_parameter("bias", None)
        self._weights_created = True

    def apply_linear(self, input, weight, bias):
    def apply_linear(self,
                     input,
                     weight,
                     bias,
                     lora_params: Optional[dict] | None = None,
                     layer_idx: Optional[int] | None = None):
        if self.has_any_quant:
            qc = self.quant_config
            if self.has_fp8_qdq:
@@ -368,6 +376,12 @@ class Linear(nn.Module):
                    out_dtype=None)
        else:
            output = F.linear(input, self.weight, bias)

        if self.lora is not None and bool(lora_params):
            lora_result = self.lora(input, lora_params, layer_idx)
            if lora_result is not None:
                output = output + lora_result

        return output

    def _maybe_fuse_bias_into_allreduce(
@@ -392,6 +406,8 @@ class Linear(nn.Module):
        input: Union[torch.Tensor, Fp4QuantizedTensor],
        *,
        all_reduce_params: Optional[AllReduceParams] = None,
        lora_params: Optional[dict] = None,
        layer_idx: Optional[int] = None,
    ) -> torch.Tensor:
        from ..distributed import allgather
@@ -401,19 +417,23 @@ class Linear(nn.Module):
                fuse_bias = self._maybe_fuse_bias_into_allreduce(
                    bias, all_reduce_params)
                bias = None if fuse_bias else bias
                output = self.apply_linear(input, self.weight, bias)
                output = self.apply_linear(input, self.weight, bias,
                                           lora_params, layer_idx)
                output = self.all_reduce(
                    output,
                    all_reduce_params=all_reduce_params,
                )
            else:
                output = self.apply_linear(input, self.weight, bias)
                output = self.apply_linear(input, self.weight, bias,
                                           lora_params, layer_idx)
        elif self.tp_mode == TensorParallelMode.COLUMN:
            output = self.apply_linear(input, self.weight, self.bias)
            output = self.apply_linear(input, self.weight, self.bias,
                                       lora_params, layer_idx)
            if self.gather_output:
                output = allgather(output, self.mapping)
        else:
            output = self.apply_linear(input, self.weight, self.bias)
            output = self.apply_linear(input, self.weight, self.bias,
                                       lora_params, layer_idx)

        return output
@@ -28,6 +28,10 @@ class MLP(nn.Module):
        self.activation = activation

        config = config or ModelConfig()
        self.up_lora = LoraLayer(
            [LoraModuleType.MLP_H_TO_4H],
            [self.intermediate_size // config.mapping.tp_size])

        self.up_proj = Linear(
            self.hidden_size,
            self.intermediate_size,
@@ -38,8 +42,11 @@ class MLP(nn.Module):
            weights_loading_config=WeightsLoadingConfig(
                weight_mode=WeightMode.VANILLA),
            quant_config=config.get_quant_config(),
            skip_create_weights_in_init=config.skip_create_weights_in_init)
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            lora=self.up_lora)

        self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
                                   [self.hidden_size])
        self.down_proj = Linear(
            self.intermediate_size,
            self.hidden_size,
@@ -48,12 +55,8 @@ class MLP(nn.Module):
            mapping=config.mapping,
            tensor_parallel_mode=TensorParallelMode.ROW,
            quant_config=config.get_quant_config(),
            skip_create_weights_in_init=config.skip_create_weights_in_init)

        self.up_lora = LoraLayer([LoraModuleType.MLP_H_TO_4H],
                                 [self.intermediate_size])
        self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
                                   [self.hidden_size])
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            lora=self.down_lora)

    def forward(
        self,
@@ -84,10 +87,8 @@ class MLP(nn.Module):
                x_up = x_up + x_up_lora

        x_act = self.activation(x_up)
        x_down = self.down_proj(x_act)

        x_down_lora = self.down_lora(x_act, lora_params, self.layer_idx)
        if x_down_lora is not None:
            x_down = x_down + x_down_lora
        x_down = self.down_proj(x_act,
                                lora_params=lora_params,
                                layer_idx=self.layer_idx)

        return x_down
@@ -92,8 +92,12 @@ class LoraLayer(torch.nn.Module):
        self.output_hidden_sizes = output_hidden_sizes
        assert len(lora_module_types) == len(output_hidden_sizes)

    def forward(self, x, lora_params: Dict,
                layer_idx: int) -> Optional[torch.Tensor]:
    def forward(
        self,
        x,
        lora_params: Dict,
        layer_idx: int,
    ) -> Optional[torch.Tensor]:

        if bool(lora_params):
            lora_ranks = []
@@ -147,7 +151,6 @@ class LoraLayer(torch.nn.Module):
                        ],
                        dtype=x.dtype,
                        device=x.device))
            # TODO smor should be by: dim=q_lora.rank() - 1
            lora_output = torch.cat(lora_output, dim=-1)
            return lora_output
@@ -10,7 +10,9 @@ import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig, load_torch_hf_lora
from tensorrt_llm.lora_manager import (LoraConfig,
                                       get_default_trtllm_modules_to_hf_modules,
                                       load_torch_hf_lora)
from tensorrt_llm.mapping import Mapping

from ..speculative import get_num_spec_layers, get_spec_decoder
@@ -380,7 +382,16 @@ def create_py_executor_instance(dist,

    if lora_config is not None:
        from tensorrt_llm.bindings import LoraModule
        load_torch_hf_lora(lora_config)

        if len(lora_config.lora_dir) == 1:
            load_torch_hf_lora(lora_config)
        else:
            assert len(lora_config.lora_target_modules
                       ) >= 1, "Expecting at least one lora target module"
            if not bool(lora_config.trtllm_modules_to_hf_modules):
                lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules(
                )

        model_binding_config = model_engine.model.model_config.get_bindings_model_config(
        )
        lora_modules = LoraModule.create_lora_modules(
@@ -406,9 +417,19 @@ def create_py_executor_instance(dist,
            max_cpu_loras,
        )

        from tensorrt_llm.bindings import WorldConfig
        world_config = WorldConfig(
            tensor_parallelism=mapping.tp_size,
            pipeline_parallelism=mapping.pp_size,
            context_parallelism=mapping.cp_size,
            rank=dist.mapping.rank,
            gpus_per_node=dist.mapping.gpus_per_node,
        )
        peft_cache_manager = PeftCacheManager(
            peft_cache_config=executor_config.peft_cache_config,
            model_config=model_binding_config)
            model_config=model_binding_config,
            world_config=world_config,
        )
        resources["peft_cache_manager"] = peft_cache_manager
        model_engine.set_lora_model_config(
            lora_config.lora_target_modules,
@@ -30,6 +30,7 @@ KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEve
RequestList = list[LlmRequest]
PeftCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.PeftCacheManager
PeftCacheConfig = tensorrt_llm.bindings.executor.PeftCacheConfig
WorldConfig = tensorrt_llm.bindings.WorldConfig


def compute_page_count(token_count: int, tokens_per_page: int) -> int:
@@ -715,8 +716,10 @@ class ResourceManager:

class PeftCacheManager(BaseResourceManager):

    def __init__(self, peft_cache_config: PeftCacheConfig,
                 model_config: ModelConfig):
    def __init__(self,
                 peft_cache_config: PeftCacheConfig,
                 model_config: ModelConfig,
                 world_config: WorldConfig | None = None):
        import tensorrt_llm.bindings as _tb

        peft_cache_manager_config = _tb.PeftCacheManagerConfig(
@@ -735,8 +738,8 @@ class PeftCacheManager(BaseResourceManager):
            lora_prefetch_dir=peft_cache_config.lora_prefetch_dir,
        )

        # TODO smor- currently set manually, change that
        world_config = _tb.WorldConfig()
        if world_config is None:
            world_config = _tb.WorldConfig()

        BufferManager = tensorrt_llm.bindings.internal.runtime.BufferManager
        buffer_manager = BufferManager(torch.cuda.current_stream().cuda_stream,
@@ -634,7 +634,24 @@ class LLM:
        if hasattr(
                self.args, "backend"
        ) and self.args.backend == "pytorch" and self.args.lora_config is not None:
            tokenizer_path = self.args.lora_config.lora_dir[0]
            num_lora_dirs = len(self.args.lora_config.lora_dir)
            if num_lora_dirs == 1:
                tokenizer_path = self.args.lora_config.lora_dir[0]
                try:
                    tokenizer = ModelLoader.load_hf_tokenizer(
                        tokenizer_path,
                        trust_remote_code=self.args.trust_remote_code,
                        use_fast=self.args.tokenizer_mode != 'slow')
                    return tokenizer
                except Exception:
                    tokenizer_path = self.args.model
            elif num_lora_dirs > 1:
                # TODO smor- currently not supported, need to determine which tokenizer to use, if possible
                raise ValueError(
                    f"Expecting only a single lora dir, but got {num_lora_dirs}"
                )
            else:
                tokenizer_path = self.args.model
        else:
            tokenizer_path = self.args.model
        return ModelLoader.load_hf_tokenizer(
@@ -77,7 +77,7 @@ def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None):
            hf_module = m.group(3) + "." + module_name
            if hf_module not in hf_modules:
                hf_module = module_name
            assert hf_module in hf_modules, f"hf_module {hf_module} is not in supported llist {hf_modules}"
            assert hf_module in hf_modules, f"hf_module {hf_module} is not in supported list {hf_modules}"

        is_lora_a_or_b = m.group(8) is not None
        if is_lora_a_or_b:
@@ -276,6 +276,7 @@ def load_torch_hf_lora(lora_config: LoraConfig):
        lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules(
        )

    assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir"
    lora_loader = HfLoraLoader(lora_config.lora_dir)

    if len(lora_config.lora_target_modules) == 0:
@@ -7,12 +7,10 @@ import torch
from parameterized import parameterized
from transformers import LlamaConfig
from transformers import LlamaForCausalLM as HFLlamaForCausalLM
from utils.llm_data import llm_models_root
from utils.util import getSMVersion, similar, skip_gpu_memory_less_than
from utils.util import getSMVersion

import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.llm import LLM
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_llama import LlamaForCausalLM
@@ -20,11 +18,8 @@ from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
    DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.sampling_params import SamplingParams

LLAMA_3_1_8B_CONFIG = {
    "architectures": ["LlamaForCausalLM"],
@@ -368,33 +363,3 @@ class TestLlama(unittest.TestCase):
                rtol=0.4)

        kv_cache_manager.shutdown()

    @skip_gpu_memory_less_than(40 * 2**30) # 40GB memory
    def test_llama_lora(self) -> None:
        lora_config = LoraConfig(lora_dir=[
            f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b"
        ],
                                 max_lora_rank=64)
        llm = LLM(
            model=f"{llm_models_root()}/llama-models-v2/llama-v2-13b-hf",
            lora_config=lora_config,
        )

        prompts = [
            "今天天气很好,我到公园的时候,",
        ]
        references = [
            "发现公园里到处都是人,有的在跑步,有的在打羽毛球,还有的",
        ]
        sampling_params = SamplingParams(max_tokens=20,
                                         add_special_tokens=False)
        lora_req = LoRARequest(
            "task-0", 0,
            f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b")
        lora_request = [lora_req]

        outputs = llm.generate(prompts,
                               sampling_params,
                               lora_request=lora_request)

        assert similar(outputs[0].outputs[0].text, references[0])
@@ -38,9 +38,6 @@ class TestResourceManager(unittest.TestCase):
        """
        Setup the lora test data resources
        """
        # TODO smor- this should be ported to a different place, ideally run once
        # in a similar way to the cpp tests fixutres.

        cpp_script_dir = os.path.join(cls.CPP_RESOURCES_DIR, "scripts")

        generate_lora_data_args_tp1 = [
@@ -258,7 +255,6 @@ class TestResourceManager(unittest.TestCase):
        Returns:
            tuple: (weights tensor, config tensor) formatted correctly for the C++ implementation.
        """
        # TODO smor- change from custom path configuration to relative path
        lora_weights = np.load(self.TP1_WEIGHTS_PATH).astype(np.float16)
        lora_weights = np.expand_dims(lora_weights, axis=0)
        lora_config = np.load(self.TP1_CONFIG_PATH)
@@ -1329,6 +1329,7 @@ def test_executor_lookahead_decoding_config():


def llama_v2_13b_lora_test_harness(**llm_kwargs):
    # Shahar- perhaps disable build config
    hf_model_dir = get_model_path("llama-models-v2/llama-v2-13b-hf")
    hf_lora_dir = get_model_path("llama-models-v2/chinese-llama-2-lora-13b")
@@ -4,8 +4,14 @@ import pytest
from .test_llm import (global_kvcache_config,
                       tinyllama_guided_decoding_test_harness,
                       tinyllama_logits_processor_test_harness)
from tensorrt_llm.llmapi import KvCacheConfig
from .test_llm_pytorch import (llama_v2_13b_lora_test_harness,
                               llama_7b_multi_lora_test_harness)

# isort: on

global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)


@pytest.mark.gpu4
def test_tinyllama_guided_decoding_tp2pp2():
@@ -31,3 +37,15 @@ def test_tinyllama_logits_processor_2gpu(tp_size: int, pp_size: int):
    tinyllama_logits_processor_test_harness(backend="pytorch",
                                            tensor_parallel_size=tp_size,
                                            pipeline_parallel_size=pp_size)


@pytest.mark.gpu2
def test_llama_v2_13b_lora_tp2():
    llama_v2_13b_lora_test_harness(tensor_parallel_size=2,
                                   kv_cache_config=global_kv_cache_config)


@pytest.mark.gpu2
def test_llama_7b_multi_lora_tp2():
    llama_7b_multi_lora_test_harness(tensor_parallel_size=2,
                                     kv_cache_config=global_kv_cache_config)
@@ -4,14 +4,17 @@ from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
from tensorrt_llm.sampling_params import SamplingParams

# isort: off
from .test_llm import (get_model_path, global_kvcache_config, llama_model_path,
                       llm_get_stats_async_test_harness,
                       llm_get_stats_test_harness, prompts,
                       run_llm_abort_request,
                       run_llm_with_postprocess_parallel_and_result_handler,
                       tinyllama_guided_decoding_test_harness,
                       tinyllama_logits_processor_test_harness)
from utils.util import force_ampere
from .test_llm import (
    get_model_path, global_kvcache_config, llama_model_path,
    llm_get_stats_async_test_harness, llm_get_stats_test_harness, prompts,
    run_llm_abort_request, run_llm_with_postprocess_parallel_and_result_handler,
    tinyllama_guided_decoding_test_harness,
    tinyllama_logits_processor_test_harness, llama_7b_multi_lora_test_harness)
from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb
from utils.llm_data import llm_models_root
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.executor.request import LoRARequest

# isort: on
@@ -89,3 +92,85 @@ def test_llm_with_postprocess_parallel_and_result_handler(streaming):
    run_llm_with_postprocess_parallel_and_result_handler(streaming,
                                                         "pytorch",
                                                         tp_size=1)


def llama_v2_13b_lora_test_harness(**llm_kwargs) -> None:
    from tensorrt_llm._torch.llm import LLM

    lora_config = LoraConfig(lora_dir=[
        f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b"
    ],
                             max_lora_rank=64)
    llm = LLM(model=f"{llm_models_root()}/llama-models-v2/llama-v2-13b-hf",
              lora_config=lora_config,
              **llm_kwargs)

    prompts = [
        "今天天气很好,我到公园的时候,",
    ]
    references = [
        "发现公园里到处都是人,有的在跑步,有的在打羽毛球,还有的",
    ]
    sampling_params = SamplingParams(max_tokens=20, add_special_tokens=False)
    lora_req = LoRARequest(
        "task-0", 0,
        f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b")
    lora_request = [lora_req]

    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)

    assert similar(outputs[0].outputs[0].text, references[0])


def llama_7b_multi_lora_test_harness(**llm_kwargs) -> None:
    from tensorrt_llm._torch.llm import LLM

    hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf"
    hf_lora_dir1 = f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"
    hf_lora_dir2 = f"{llm_models_root()}/llama-models/Japanese-Alpaca-LoRA-7b-v0"

    # For LoRA checkpoints without finetuned embedding and lm_head, we can either:
    # (1) specify lora_target_modules, or
    # (2) provide a lora_dir to infer the lora_target_modules.
    lora_config = LoraConfig(lora_target_modules=['attn_q', 'attn_k', 'attn_v'],
                             max_lora_rank=8)
    llm = LLM(hf_model_dir,
              fast_build=True,
              lora_config=lora_config,
              **llm_kwargs)

    prompts = [
        "美国的首都在哪里? \n答案:",
        "美国的首都在哪里? \n答案:",
        "美国的首都在哪里? \n答案:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
    ]
    references = [
        "沃尔玛\n\n## 新闻\n\n* ",
        "美国的首都是华盛顿。\n\n美国的",
        "纽约\n\n### カンファレンスの",
        "Washington, D.C.\nWashington, D.C. is the capital of the United",
        "华盛顿。\n\n英国の首都是什",
        "ワシントン\nQ1. アメリカ合衆国",
    ]
    lora_req1 = LoRARequest("luotuo", 1, hf_lora_dir1)
    lora_req2 = LoRARequest("Japanese", 2, hf_lora_dir2)
    sampling_params = SamplingParams(max_tokens=20)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=[None, lora_req1, lora_req2, None, lora_req1, lora_req2])
    for output, ref in zip(outputs, references):
        assert similar(output.outputs[0].text, ref)


@skip_gpu_memory_less_than_40gb
def test_llama_v2_13b_lora():
    llama_v2_13b_lora_test_harness()


@skip_gpu_memory_less_than_40gb
def test_llama_7b_multi_lora():
    llama_7b_multi_lora_test_harness()