feat: support multi lora adapters and TP (#3885)

* support multiple LoRA adapters and tensor parallelism (TP)

Signed-off-by: Shahar Mor <17088876+shaharmor98@users.noreply.github.com>
shaharmor98 2025-05-08 18:45:45 +03:00 committed by GitHub
parent 99313af242
commit 7d94c9561f
18 changed files with 274 additions and 175 deletions

View File

@ -1,38 +0,0 @@
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig
def main():
lora_config = LoraConfig(lora_dir=[
"/home/scratch.trt_llm_data/llm-models/llama-models-v2/chinese-llama-2-lora-13b"
],
max_lora_rank=64)
llm = LLM(
model=
"/home/scratch.trt_llm_data/llm-models/llama-models-v2/llama-v2-13b-hf",
lora_config=lora_config,
)
prompts = [
"今天天气很好,我到公园的时候,",
]
sampling_params = SamplingParams(max_tokens=20, add_special_tokens=False)
lora_req_2 = LoRARequest(
"task-0", 0,
"/home/scratch.trt_llm_data/llm-models/llama-models-v2/chinese-llama-2-lora-13b"
)
lora_request = [lora_req_2]
outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
if __name__ == '__main__':
main()
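
The deleted example above covered a single adapter. Below is a minimal usage sketch of the multi-adapter flow this commit enables, mirroring the `llama_7b_multi_lora_test_harness` added later in this diff; the model and adapter paths are placeholders, and `tensor_parallel_size=2` is an illustrative setting rather than a requirement.

```python
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig


def main():
    # For LoRA checkpoints without finetuned embedding/lm_head, name the
    # target modules explicitly (same configuration as the test harness below).
    lora_config = LoraConfig(lora_target_modules=['attn_q', 'attn_k', 'attn_v'],
                             max_lora_rank=8)
    llm = LLM(
        model="<path-to-llama-7b-hf>",  # placeholder model path
        lora_config=lora_config,
        tensor_parallel_size=2,
    )
    lora_req1 = LoRARequest("adapter-1", 1, "<path-to-lora-adapter-1>")  # placeholder
    lora_req2 = LoRARequest("adapter-2", 2, "<path-to-lora-adapter-2>")  # placeholder
    prompts = [
        "美国的首都在哪里? \n答案:",
        "アメリカ合衆国の首都はどこですか? \n答え:",
    ]
    sampling_params = SamplingParams(max_tokens=20)
    # One LoRARequest (or None for the base model) per prompt.
    outputs = llm.generate(prompts, sampling_params,
                           lora_request=[lora_req1, lora_req2])
    for output in outputs:
        print(f"Prompt: {output.prompt!r}\nGenerated text: {output.outputs[0].text!r}")


if __name__ == '__main__':
    main()
```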

View File

@ -168,22 +168,18 @@ class ModelConfig(Generic[TConfig]):
quant_config_dict=layer_quant_config,
**kwargs)
def get_bindings_model_config(
self,
tensor_parallelism: int = 1,
context_parallelism: int = 1) -> "ModelConfigCpp":
def get_bindings_model_config(self) -> "ModelConfigCpp":
"""
This method is used to construct the bindings config for the model.
Currently it adheres to gptJsonConfig.cpp::createModelConfig, which assumes
that an engine has been created.
"""
# TODO smor- this isn't robust, and currently tested for LlamaConfig only
# TODO smor- currently parallelism is not supported, set default to 1
# TODO smor- currently assuming no rnn layers, no MOE
from tensorrt_llm.bindings import ModelConfig as ModelConfigCpp
num_heads = self.pretrained_config.num_attention_heads // (
tensor_parallelism * context_parallelism)
self.mapping.tp_size * self.mapping.cp_size)
model_config_cpp = ModelConfigCpp(
vocab_size=self.pretrained_config.vocab_size,
@ -195,7 +191,7 @@ class ModelConfig(Generic[TConfig]):
data_type=torch_dtype_to_binding(
self.pretrained_config.torch_dtype))
mlp_hidden_size = self.pretrained_config.intermediate_size // tensor_parallelism
mlp_hidden_size = self.pretrained_config.intermediate_size // self.mapping.tp_size
if "head_size" in self.pretrained_config:
head_size = self.pretrained_config.head_size
else:

View File

@ -13,6 +13,7 @@ from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
AllReduceParams, DeepseekAllReduce)
from tensorrt_llm._torch.pipeline_interface import PipelineInterface
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.models.convert_utils import split_matrix_tp
from ...inputs import (ExtraProcessedInputs, InputProcessor, TextPrompt,
register_input_processor)
@ -773,13 +774,14 @@ class LlamaModel(DecoderModel):
self.padding_idx = config.pad_token_id
vocab_size = config.vocab_size
# TODO smor- hack
if hasattr(model_config,
'lora_config') and model_config.lora_config is not None:
# TODO smor- we load manually only if there is a single lora dir, need to come up with a better solution
if hasattr(
model_config,
'lora_config') and model_config.lora_config is not None and len(
model_config.lora_config.lora_dir) == 1:
from tensorrt_llm.lora_manager import HfLoraLoader
lora_loader = HfLoraLoader(model_config.lora_config.lora_dir)
weight = lora_loader.embed_tokens
# TODO smor - need to split tp matrix here
vocab_size = lora_loader.vocab_size
self.embed_tokens = Embedding(
@ -791,9 +793,17 @@ class LlamaModel(DecoderModel):
gather_output=True,
)
if hasattr(model_config,
'lora_config') and model_config.lora_config is not None:
if hasattr(
model_config,
'lora_config') and model_config.lora_config is not None and len(
model_config.lora_config.lora_dir) == 1:
with torch.no_grad():
if model_config.mapping.tp_size > 1:
weight = split_matrix_tp(
weight,
model_config.mapping.tp_size,
model_config.mapping.tp_rank,
dim=0) # split by vocabulary dimension
x = weight.to(self.embed_tokens.dtype)
self.embed_tokens.weight.data.copy_(x)
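
When tp_size > 1, the embedding weight pulled from the LoRA checkpoint is sharded along the vocabulary dimension before being copied into the local embed_tokens table. A toy illustration of the split performed by the split_matrix_tp call above (the tensor shape is a placeholder, and this assumes split_matrix_tp takes an even, contiguous chunk along the given dimension, matching its use here):

```python
import torch

# Toy "embed_tokens" weight from the adapter checkpoint: vocab_size=8, hidden=4.
weight = torch.arange(32, dtype=torch.float32).reshape(8, 4)

tp_size = 2
# Effect of split_matrix_tp(weight, tp_size, tp_rank, dim=0): each rank keeps a
# contiguous slice of the vocabulary rows before copy_() into its embedding table.
shards = list(torch.chunk(weight, tp_size, dim=0))
assert shards[0].shape == (4, 4)
assert torch.equal(torch.cat(shards, dim=0), weight)
```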

View File

@ -11,6 +11,9 @@ from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_any_only
from tqdm import tqdm
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split_matrix_tp
from ...logger import logger
from ...mapping import Mapping
from ...models.modeling_utils import QuantConfig
@ -240,7 +243,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
lora_params: Optional = None, # TODO smor add type hint
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
@ -357,9 +360,9 @@ class DecoderModelForCausalLM(nn.Module,
# TODO(zhenhuanc): Currently lm_head Linear will not accept QuantConfig
# will considering per layer QuantConfig in the future.
# TODO smor- hack
if hasattr(config,
'lora_config') and config.lora_config is not None:
if hasattr(config, 'lora_config'
) and config.lora_config is not None and len(
config.lora_config.lora_dir) == 1:
from tensorrt_llm.lora_manager import HfLoraLoader
lora_loader = HfLoraLoader(config.lora_config.lora_dir)
weight = lora_loader.lm_head
@ -374,9 +377,16 @@ class DecoderModelForCausalLM(nn.Module,
gather_output=True,
)
if hasattr(config,
'lora_config') and config.lora_config is not None:
if hasattr(config, 'lora_config'
) and config.lora_config is not None and len(
config.lora_config.lora_dir) == 1:
with torch.no_grad():
if config.mapping.tp_size > 1:
weight = split_matrix_tp(
weight,
config.mapping.tp_size,
config.mapping.tp_rank,
dim=0) # split by vocabulary dimension
x = weight.to(self.lm_head.dtype).cuda()
self.lm_head.weight.data.copy_(x)
@ -475,7 +485,7 @@ class DecoderModelForCausalLM(nn.Module,
pipeline_interface: Optional[PipelineInterface] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
lora_params: Optional = None, # TODO smor add type hint
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if self._supports_pp and self.pp_size > 1:
@ -657,8 +667,10 @@ def _load_weights_impl(model: Union[nn.Module, DecoderModelForCausalLM],
# Skip loading weights for embedding and lm_head if LoRA is enabled
if hasattr(model.model_config, 'lora_config'
) and model.model_config.lora_config is not None and (
name == "model.embed_tokens" or name == "lm_head"):
) and model.model_config.lora_config is not None and len(
model.model_config.lora_config.lora_dir) == 1 and (
name == "model.embed_tokens"
or name == "lm_head"):
continue
# Skip if parameter belongs to a missing layer

View File

@ -88,6 +88,9 @@ class Attention(nn.Module):
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
)
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])
self.o_proj = Linear(
tp_size * self.q_size,
self.hidden_size,
@ -97,6 +100,7 @@ class Attention(nn.Module):
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.o_lora,
)
self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
@ -229,13 +233,9 @@ class Attention(nn.Module):
mrope_config=mrope_config)
hidden_states = attn_output
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params)
if bool(lora_params):
attn_lora_output = self.o_lora(hidden_states, lora_params,
self.layer_idx)
if attn_lora_output is not None:
attn_output = attn_output + attn_lora_output
all_reduce_params=all_reduce_params,
lora_params=lora_params,
layer_idx=self.layer_idx)
return attn_output

View File

@ -76,6 +76,9 @@ class GatedMLP(nn.Module):
reduce_output=False,
skip_create_weights_in_init=config.skip_create_weights_in_init,
)
self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
[self.hidden_size])
self.down_proj = Linear(
self.intermediate_size,
self.hidden_size,
@ -86,18 +89,20 @@ class GatedMLP(nn.Module):
quant_config=config.get_quant_config(),
reduce_output=reduce_output,
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.down_lora,
)
# These two modules are mutually exclusive - either splitted_gate_up_lora or fused_gate_up_lora will be used,
# but never both at the same time. splitted_gate_up_lora handles gate and up separately while fused_gate_up_lora
# handles them as a single fused operation.
self.splitted_gate_up_lora = LoraLayer(
[LoraModuleType.MLP_H_TO_4H, LoraModuleType.MLP_GATE],
[self.intermediate_size, self.intermediate_size])
self.fused_gate_up_lora = LoraLayer([LoraModuleType.MLP_GATE_UP],
[2 * self.intermediate_size])
self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
[self.hidden_size])
[LoraModuleType.MLP_H_TO_4H, LoraModuleType.MLP_GATE], [
self.intermediate_size // mapping.tp_size,
self.intermediate_size // mapping.tp_size
])
self.fused_gate_up_lora = LoraLayer(
[LoraModuleType.MLP_GATE_UP],
[2 * self.intermediate_size // mapping.tp_size])
def forward(
self,
@ -107,33 +112,17 @@ class GatedMLP(nn.Module):
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:
if lora_params is not None:
if bool(lora_params):
return self.forward_lora(x, all_rank_num_tokens,
final_all_reduce_params, lora_params)
if self.activation == F.silu:
h1 = self.gate_up_proj(x)
if bool(lora_params):
assert self.layer_idx is not None, "layer_idx is required for lora"
h1_lora = self.splitted_gate_up_lora(x, lora_params,
self.layer_idx)
if h1_lora is not None:
h1 = h1 + h1_lora
h1_lora = self.fused_gate_up_lora(x, lora_params,
self.layer_idx)
if h1_lora is not None:
h1 = h1 + h1_lora
h2 = swiglu(h1)
output = self.down_proj(h2,
all_reduce_params=final_all_reduce_params)
if bool(lora_params):
output_lora = self.down_lora(h2, lora_params, self.layer_idx)
if output_lora is not None:
output = output + output_lora
all_reduce_params=final_all_reduce_params,
layer_idx=self.layer_idx)
return output
else:
raise NotImplementedError(
@ -154,19 +143,18 @@ class GatedMLP(nn.Module):
h1 = self.gate_up_proj(x)
h1_lora = self.splitted_gate_up_lora(x, lora_params, self.layer_idx)
if h1_lora is not None:
h1 = h1 + h1_lora
h1_lora = self.fused_gate_up_lora(x, lora_params, self.layer_idx)
if h1_lora is not None:
h1 = h1 + h1_lora
h2 = swiglu(h1)
output = self.down_proj(h2, all_reduce_params=final_all_reduce_params)
output_lora = self.down_lora(h2, lora_params, self.layer_idx)
if output_lora is not None:
output = output + output_lora
output = self.down_proj(h2,
all_reduce_params=final_all_reduce_params,
lora_params=lora_params,
layer_idx=self.layer_idx)
return output
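
Because gate_up_proj is column-parallel, each rank only produces its local shard of the intermediate activations, which is why the LoRA output sizes above are divided by mapping.tp_size. A small sketch of the per-rank sizing arithmetic (intermediate_size is an illustrative placeholder, not read from a real config):

```python
# Illustrative per-rank LoRA output sizes for a TP=2 GatedMLP.
intermediate_size = 11008
tp_size = 2

# Split path: separate gate and up adapters, each sized to the local shard.
splitted_sizes = [intermediate_size // tp_size, intermediate_size // tp_size]  # [5504, 5504]
# Fused path: one adapter covering the concatenated [gate, up] local shard.
fused_sizes = [2 * intermediate_size // tp_size]                               # [11008]

assert sum(splitted_sizes) == fused_sizes[0]
```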

View File

@ -9,6 +9,7 @@ from torch import nn
from torch.nn.parameter import Parameter
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.peft.lora.layer import LoraLayer
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams
from tensorrt_llm.mapping import Mapping
@ -157,6 +158,7 @@ class Linear(nn.Module):
reduce_output: bool = True, # ROW parallel only
skip_create_weights_in_init: bool = False,
use_custom_cublas_mm: bool = False,
lora: Optional[LoraLayer] = None,
):
from ..distributed import AllReduce
@ -197,6 +199,7 @@ class Linear(nn.Module):
self._weights_created = False
self.reduce_output = reduce_output
self.use_custom_cublas_mm = use_custom_cublas_mm
self.lora = lora
if not skip_create_weights_in_init:
self.create_weights()
@ -310,7 +313,12 @@ class Linear(nn.Module):
self.register_parameter("bias", None)
self._weights_created = True
def apply_linear(self, input, weight, bias):
def apply_linear(self,
input,
weight,
bias,
lora_params: Optional[dict] = None,
layer_idx: Optional[int] = None):
if self.has_any_quant:
qc = self.quant_config
if self.has_fp8_qdq:
@ -368,6 +376,12 @@ class Linear(nn.Module):
out_dtype=None)
else:
output = F.linear(input, self.weight, bias)
if self.lora is not None and bool(lora_params):
lora_result = self.lora(input, lora_params, layer_idx)
if lora_result is not None:
output = output + lora_result
return output
def _maybe_fuse_bias_into_allreduce(
@ -392,6 +406,8 @@ class Linear(nn.Module):
input: Union[torch.Tensor, Fp4QuantizedTensor],
*,
all_reduce_params: Optional[AllReduceParams] = None,
lora_params: Optional[dict] = None,
layer_idx: Optional[int] = None,
) -> torch.Tensor:
from ..distributed import allgather
@ -401,19 +417,23 @@ class Linear(nn.Module):
fuse_bias = self._maybe_fuse_bias_into_allreduce(
bias, all_reduce_params)
bias = None if fuse_bias else bias
output = self.apply_linear(input, self.weight, bias)
output = self.apply_linear(input, self.weight, bias,
lora_params, layer_idx)
output = self.all_reduce(
output,
all_reduce_params=all_reduce_params,
)
else:
output = self.apply_linear(input, self.weight, bias)
output = self.apply_linear(input, self.weight, bias,
lora_params, layer_idx)
elif self.tp_mode == TensorParallelMode.COLUMN:
output = self.apply_linear(input, self.weight, self.bias)
output = self.apply_linear(input, self.weight, self.bias,
lora_params, layer_idx)
if self.gather_output:
output = allgather(output, self.mapping)
else:
output = self.apply_linear(input, self.weight, self.bias)
output = self.apply_linear(input, self.weight, self.bias,
lora_params, layer_idx)
return output
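
With this change a Linear can own a LoraLayer and applies it immediately after the base GEMM, so in ROW-parallel mode the LoRA contribution is folded into the local partial result before the all-reduce. A standalone sketch of that ordering (the tiny adapter and layer below are illustrative stand-ins, not the TRT-LLM classes):

```python
import torch
import torch.nn.functional as F


class TinyLora(torch.nn.Module):
    """Illustrative low-rank adapter: out = (x @ A) @ B."""

    def __init__(self, in_features: int, out_features: int, rank: int = 4):
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(in_features, rank) * 0.01)
        self.b = torch.nn.Parameter(torch.zeros(rank, out_features))

    def forward(self, x, lora_params, layer_idx):
        # Mirrors the guard in apply_linear: skip when no LoRA request is active.
        if not lora_params:
            return None
        return (x @ self.a) @ self.b


class TinyLinearWithLora(torch.nn.Module):
    def __init__(self, in_features, out_features, lora=None):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
        self.lora = lora

    def forward(self, x, lora_params=None, layer_idx=None):
        output = F.linear(x, self.weight)
        # Same ordering as apply_linear: add the LoRA delta to the local result;
        # in ROW-parallel mode the all-reduce would run after this point.
        if self.lora is not None and bool(lora_params):
            delta = self.lora(x, lora_params, layer_idx)
            if delta is not None:
                output = output + delta
        return output


layer = TinyLinearWithLora(16, 8, lora=TinyLora(16, 8))
x = torch.randn(2, 16)
print(layer(x, lora_params={"task": 0}, layer_idx=0).shape)  # torch.Size([2, 8])
```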

View File

@ -28,6 +28,10 @@ class MLP(nn.Module):
self.activation = activation
config = config or ModelConfig()
self.up_lora = LoraLayer(
[LoraModuleType.MLP_H_TO_4H],
[self.intermediate_size // config.mapping.tp_size])
self.up_proj = Linear(
self.hidden_size,
self.intermediate_size,
@ -38,8 +42,11 @@ class MLP(nn.Module):
weights_loading_config=WeightsLoadingConfig(
weight_mode=WeightMode.VANILLA),
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init)
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.up_lora)
self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
[self.hidden_size])
self.down_proj = Linear(
self.intermediate_size,
self.hidden_size,
@ -48,12 +55,8 @@ class MLP(nn.Module):
mapping=config.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init)
self.up_lora = LoraLayer([LoraModuleType.MLP_H_TO_4H],
[self.intermediate_size])
self.down_lora = LoraLayer([LoraModuleType.MLP_4H_TO_H],
[self.hidden_size])
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.down_lora)
def forward(
self,
@ -84,10 +87,8 @@ class MLP(nn.Module):
x_up = x_up + x_up_lora
x_act = self.activation(x_up)
x_down = self.down_proj(x_act)
x_down_lora = self.down_lora(x_act, lora_params, self.layer_idx)
if x_down_lora is not None:
x_down = x_down + x_down_lora
x_down = self.down_proj(x_act,
lora_params=lora_params,
layer_idx=self.layer_idx)
return x_down

View File

@ -92,8 +92,12 @@ class LoraLayer(torch.nn.Module):
self.output_hidden_sizes = output_hidden_sizes
assert len(lora_module_types) == len(output_hidden_sizes)
def forward(self, x, lora_params: Dict,
layer_idx: int) -> Optional[torch.Tensor]:
def forward(
self,
x,
lora_params: Dict,
layer_idx: int,
) -> Optional[torch.Tensor]:
if bool(lora_params):
lora_ranks = []
@ -147,7 +151,6 @@ class LoraLayer(torch.nn.Module):
],
dtype=x.dtype,
device=x.device))
# TODO smor should be by: dim=q_lora.rank() - 1
lora_output = torch.cat(lora_output, dim=-1)
return lora_output
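
When a LoraLayer serves several module types at once (e.g. gate and up), the grouped results are concatenated along the last dimension so they line up with the fused projection output they are added to in the callers above. A toy illustration with placeholder widths:

```python
import torch

# Per-rank LoRA outputs for two module types (e.g. gate and up),
# each with a local width of 4 (placeholder values).
gate_lora_out = torch.randn(2, 4)
up_lora_out = torch.randn(2, 4)

# Concatenated along the last dim, matching the layout of the fused
# projection output that the caller adds it to.
lora_output = torch.cat([gate_lora_out, up_lora_out], dim=-1)
assert lora_output.shape == (2, 8)
```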

View File

@ -10,7 +10,9 @@ import tensorrt_llm.bindings.executor as trtllm
from tensorrt_llm._utils import str_dtype_to_binding, torch_dtype_to_str
from tensorrt_llm.bindings.executor import DecodingMode, ExecutorConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraConfig, load_torch_hf_lora
from tensorrt_llm.lora_manager import (LoraConfig,
get_default_trtllm_modules_to_hf_modules,
load_torch_hf_lora)
from tensorrt_llm.mapping import Mapping
from ..speculative import get_num_spec_layers, get_spec_decoder
@ -380,7 +382,16 @@ def create_py_executor_instance(dist,
if lora_config is not None:
from tensorrt_llm.bindings import LoraModule
load_torch_hf_lora(lora_config)
if len(lora_config.lora_dir) == 1:
load_torch_hf_lora(lora_config)
else:
assert len(lora_config.lora_target_modules
) >= 1, "Expecting at least one lora target module"
if not bool(lora_config.trtllm_modules_to_hf_modules):
lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules(
)
model_binding_config = model_engine.model.model_config.get_bindings_model_config(
)
lora_modules = LoraModule.create_lora_modules(
@ -406,9 +417,19 @@ def create_py_executor_instance(dist,
max_cpu_loras,
)
from tensorrt_llm.bindings import WorldConfig
world_config = WorldConfig(
tensor_parallelism=mapping.tp_size,
pipeline_parallelism=mapping.pp_size,
context_parallelism=mapping.cp_size,
rank=dist.mapping.rank,
gpus_per_node=dist.mapping.gpus_per_node,
)
peft_cache_manager = PeftCacheManager(
peft_cache_config=executor_config.peft_cache_config,
model_config=model_binding_config)
model_config=model_binding_config,
world_config=world_config,
)
resources["peft_cache_manager"] = peft_cache_manager
model_engine.set_lora_model_config(
lora_config.lora_target_modules,

View File

@ -30,6 +30,7 @@ KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEve
RequestList = list[LlmRequest]
PeftCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.PeftCacheManager
PeftCacheConfig = tensorrt_llm.bindings.executor.PeftCacheConfig
WorldConfig = tensorrt_llm.bindings.WorldConfig
def compute_page_count(token_count: int, tokens_per_page: int) -> int:
@ -715,8 +716,10 @@ class ResourceManager:
class PeftCacheManager(BaseResourceManager):
def __init__(self, peft_cache_config: PeftCacheConfig,
model_config: ModelConfig):
def __init__(self,
peft_cache_config: PeftCacheConfig,
model_config: ModelConfig,
world_config: WorldConfig | None = None):
import tensorrt_llm.bindings as _tb
peft_cache_manager_config = _tb.PeftCacheManagerConfig(
@ -735,8 +738,8 @@ class PeftCacheManager(BaseResourceManager):
lora_prefetch_dir=peft_cache_config.lora_prefetch_dir,
)
# TODO smor- currently set manually, change that
world_config = _tb.WorldConfig()
if world_config is None:
world_config = _tb.WorldConfig()
BufferManager = tensorrt_llm.bindings.internal.runtime.BufferManager
buffer_manager = BufferManager(torch.cuda.current_stream().cuda_stream,

View File

@ -634,7 +634,24 @@ class LLM:
if hasattr(
self.args, "backend"
) and self.args.backend == "pytorch" and self.args.lora_config is not None:
tokenizer_path = self.args.lora_config.lora_dir[0]
num_lora_dirs = len(self.args.lora_config.lora_dir)
if num_lora_dirs == 1:
tokenizer_path = self.args.lora_config.lora_dir[0]
try:
tokenizer = ModelLoader.load_hf_tokenizer(
tokenizer_path,
trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.tokenizer_mode != 'slow')
return tokenizer
except Exception:
tokenizer_path = self.args.model
elif num_lora_dirs > 1:
# TODO smor- currently not supported, need to determine which tokenizer to use, if possible
raise ValueError(
f"Expecting only a single lora dir, but got {num_lora_dirs}"
)
else:
tokenizer_path = self.args.model
else:
tokenizer_path = self.args.model
return ModelLoader.load_hf_tokenizer(

View File

@ -77,7 +77,7 @@ def iterate_hf_lora(iter_fn, lora_weights, hf_modules, component=None):
hf_module = m.group(3) + "." + module_name
if hf_module not in hf_modules:
hf_module = module_name
assert hf_module in hf_modules, f"hf_module {hf_module} is not in supported llist {hf_modules}"
assert hf_module in hf_modules, f"hf_module {hf_module} is not in supported list {hf_modules}"
is_lora_a_or_b = m.group(8) is not None
if is_lora_a_or_b:
@ -276,6 +276,7 @@ def load_torch_hf_lora(lora_config: LoraConfig):
lora_config.trtllm_modules_to_hf_modules = get_default_trtllm_modules_to_hf_modules(
)
assert len(lora_config.lora_dir) == 1, "Expecting only a single lora dir"
lora_loader = HfLoraLoader(lora_config.lora_dir)
if len(lora_config.lora_target_modules) == 0:

View File

@ -7,12 +7,10 @@ import torch
from parameterized import parameterized
from transformers import LlamaConfig
from transformers import LlamaForCausalLM as HFLlamaForCausalLM
from utils.llm_data import llm_models_root
from utils.util import getSMVersion, similar, skip_gpu_memory_less_than
from utils.util import getSMVersion
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.llm import LLM
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_llama import LlamaForCausalLM
@ -20,11 +18,8 @@ from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import \
DecodingCUDAGraphRunner
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.executor.request import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.sampling_params import SamplingParams
LLAMA_3_1_8B_CONFIG = {
"architectures": ["LlamaForCausalLM"],
@ -368,33 +363,3 @@ class TestLlama(unittest.TestCase):
rtol=0.4)
kv_cache_manager.shutdown()
@skip_gpu_memory_less_than(40 * 2**30) # 40GB memory
def test_llama_lora(self) -> None:
lora_config = LoraConfig(lora_dir=[
f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b"
],
max_lora_rank=64)
llm = LLM(
model=f"{llm_models_root()}/llama-models-v2/llama-v2-13b-hf",
lora_config=lora_config,
)
prompts = [
"今天天气很好,我到公园的时候,",
]
references = [
"发现公园里到处都是人,有的在跑步,有的在打羽毛球,还有的",
]
sampling_params = SamplingParams(max_tokens=20,
add_special_tokens=False)
lora_req = LoRARequest(
"task-0", 0,
f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b")
lora_request = [lora_req]
outputs = llm.generate(prompts,
sampling_params,
lora_request=lora_request)
assert similar(outputs[0].outputs[0].text, references[0])

View File

@ -38,9 +38,6 @@ class TestResourceManager(unittest.TestCase):
"""
Setup the lora test data resources
"""
# TODO smor- this should be ported to a different place, ideally run once
# in a similar way to the cpp tests fixutres.
cpp_script_dir = os.path.join(cls.CPP_RESOURCES_DIR, "scripts")
generate_lora_data_args_tp1 = [
@ -258,7 +255,6 @@ class TestResourceManager(unittest.TestCase):
Returns:
tuple: (weights tensor, config tensor) formatted correctly for the C++ implementation.
"""
# TODO smor- change from custom path configuration to relative path
lora_weights = np.load(self.TP1_WEIGHTS_PATH).astype(np.float16)
lora_weights = np.expand_dims(lora_weights, axis=0)
lora_config = np.load(self.TP1_CONFIG_PATH)

View File

@ -1329,6 +1329,7 @@ def test_executor_lookahead_decoding_config():
def llama_v2_13b_lora_test_harness(**llm_kwargs):
# Shahar- perhaps disable build config
hf_model_dir = get_model_path("llama-models-v2/llama-v2-13b-hf")
hf_lora_dir = get_model_path("llama-models-v2/chinese-llama-2-lora-13b")

View File

@ -4,8 +4,14 @@ import pytest
from .test_llm import (global_kvcache_config,
tinyllama_guided_decoding_test_harness,
tinyllama_logits_processor_test_harness)
from tensorrt_llm.llmapi import KvCacheConfig
from .test_llm_pytorch import (llama_v2_13b_lora_test_harness,
llama_7b_multi_lora_test_harness)
# isort: on
global_kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.4)
@pytest.mark.gpu4
def test_tinyllama_guided_decoding_tp2pp2():
@ -31,3 +37,15 @@ def test_tinyllama_logits_processor_2gpu(tp_size: int, pp_size: int):
tinyllama_logits_processor_test_harness(backend="pytorch",
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size)
@pytest.mark.gpu2
def test_llama_v2_13b_lora_tp2():
llama_v2_13b_lora_test_harness(tensor_parallel_size=2,
kv_cache_config=global_kv_cache_config)
@pytest.mark.gpu2
def test_llama_7b_multi_lora_tp2():
llama_7b_multi_lora_test_harness(tensor_parallel_size=2,
kv_cache_config=global_kv_cache_config)

View File

@ -4,14 +4,17 @@ from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
from tensorrt_llm.sampling_params import SamplingParams
# isort: off
from .test_llm import (get_model_path, global_kvcache_config, llama_model_path,
llm_get_stats_async_test_harness,
llm_get_stats_test_harness, prompts,
run_llm_abort_request,
run_llm_with_postprocess_parallel_and_result_handler,
tinyllama_guided_decoding_test_harness,
tinyllama_logits_processor_test_harness)
from utils.util import force_ampere
from .test_llm import (
get_model_path, global_kvcache_config, llama_model_path,
llm_get_stats_async_test_harness, llm_get_stats_test_harness, prompts,
run_llm_abort_request, run_llm_with_postprocess_parallel_and_result_handler,
tinyllama_guided_decoding_test_harness,
tinyllama_logits_processor_test_harness, llama_7b_multi_lora_test_harness)
from utils.util import force_ampere, similar, skip_gpu_memory_less_than_40gb
from utils.llm_data import llm_models_root
from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.executor.request import LoRARequest
# isort: on
@ -89,3 +92,85 @@ def test_llm_with_postprocess_parallel_and_result_handler(streaming):
run_llm_with_postprocess_parallel_and_result_handler(streaming,
"pytorch",
tp_size=1)
def llama_v2_13b_lora_test_harness(**llm_kwargs) -> None:
from tensorrt_llm._torch.llm import LLM
lora_config = LoraConfig(lora_dir=[
f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b"
],
max_lora_rank=64)
llm = LLM(model=f"{llm_models_root()}/llama-models-v2/llama-v2-13b-hf",
lora_config=lora_config,
**llm_kwargs)
prompts = [
"今天天气很好,我到公园的时候,",
]
references = [
"发现公园里到处都是人,有的在跑步,有的在打羽毛球,还有的",
]
sampling_params = SamplingParams(max_tokens=20, add_special_tokens=False)
lora_req = LoRARequest(
"task-0", 0,
f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b")
lora_request = [lora_req]
outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
assert similar(outputs[0].outputs[0].text, references[0])
def llama_7b_multi_lora_test_harness(**llm_kwargs) -> None:
from tensorrt_llm._torch.llm import LLM
hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf"
hf_lora_dir1 = f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"
hf_lora_dir2 = f"{llm_models_root()}/llama-models/Japanese-Alpaca-LoRA-7b-v0"
# For LoRA checkpoints without finetuned embedding and lm_head, we can either:
# (1) specify lora_target_modules, or
# (2) provide a lora_dir to infer the lora_target_modules.
lora_config = LoraConfig(lora_target_modules=['attn_q', 'attn_k', 'attn_v'],
max_lora_rank=8)
llm = LLM(hf_model_dir,
fast_build=True,
lora_config=lora_config,
**llm_kwargs)
prompts = [
"美国的首都在哪里? \n答案:",
"美国的首都在哪里? \n答案:",
"美国的首都在哪里? \n答案:",
"アメリカ合衆国の首都はどこですか? \n答え:",
"アメリカ合衆国の首都はどこですか? \n答え:",
"アメリカ合衆国の首都はどこですか? \n答え:",
]
references = [
"沃尔玛\n\n## 新闻\n\n* ",
"美国的首都是华盛顿。\n\n美国的",
"纽约\n\n### カンファレンスの",
"Washington, D.C.\nWashington, D.C. is the capital of the United",
"华盛顿。\n\n英国の首都是什",
"ワシントン\nQ1. アメリカ合衆国",
]
lora_req1 = LoRARequest("luotuo", 1, hf_lora_dir1)
lora_req2 = LoRARequest("Japanese", 2, hf_lora_dir2)
sampling_params = SamplingParams(max_tokens=20)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=[None, lora_req1, lora_req2, None, lora_req1, lora_req2])
for output, ref in zip(outputs, references):
assert similar(output.outputs[0].text, ref)
@skip_gpu_memory_less_than_40gb
def test_llama_v2_13b_lora():
llama_v2_13b_lora_test_harness()
@skip_gpu_memory_less_than_40gb
def test_llama_7b_multi_lora():
llama_7b_multi_lora_test_harness()