add passing E2E LoRA flow (#3788)

Signed-off-by: Shahar Mor <smor@nvidia.com>

commit 49262a62a5 (parent a51b3cf7a6)
examples/pytorch/quickstart_lora.py (new file, 38 lines added)
@@ -0,0 +1,38 @@
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig


def main():
    lora_config = LoraConfig(lora_dir=[
        "/home/scratch.trt_llm_data/llm-models/llama-models-v2/chinese-llama-2-lora-13b"
    ],
                             max_lora_rank=64)
    llm = LLM(
        model=
        "/home/scratch.trt_llm_data/llm-models/llama-models-v2/llama-v2-13b-hf",
        lora_config=lora_config,
    )
    prompts = [
        "今天天气很好,我到公园的时候,",
    ]

    sampling_params = SamplingParams(max_tokens=20, add_special_tokens=False)
    lora_req_2 = LoRARequest(
        "task-0", 0,
        "/home/scratch.trt_llm_data/llm-models/llama-models-v2/chinese-llama-2-lora-13b"
    )
    lora_request = [lora_req_2]

    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")


if __name__ == '__main__':
    main()
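Note: the Chinese prompt above roughly translates to "The weather is nice today; when I got to the park,". The sketch below is a hypothetical extension of this quickstart, not part of the commit: it assumes llm.generate accepts a per-prompt lora_request list with None entries for base-model requests, mirroring the single-adapter list used above, and all paths are placeholders.

    from tensorrt_llm import SamplingParams
    from tensorrt_llm._torch import LLM
    from tensorrt_llm.executor import LoRARequest
    from tensorrt_llm.lora_manager import LoraConfig

    # Placeholder paths; substitute your own checkpoints.
    lora_dir = "/path/to/chinese-llama-2-lora-13b"
    base_model = "/path/to/llama-v2-13b-hf"

    llm = LLM(model=base_model,
              lora_config=LoraConfig(lora_dir=[lora_dir], max_lora_rank=64))

    prompts = [
        "今天天气很好,我到公园的时候,",  # served with the LoRA adapter
        "The weather is nice today. When I got to the park,",  # base model only
    ]
    # Assumption: one entry per prompt, None meaning "no adapter".
    lora_requests = [LoRARequest("task-0", 0, lora_dir), None]

    outputs = llm.generate(prompts,
                           SamplingParams(max_tokens=20, add_special_tokens=False),
                           lora_request=lora_requests)
    for out in outputs:
        print(f"Prompt: {out.prompt!r}\nGenerated text: {out.outputs[0].text!r}")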
@@ -757,7 +757,9 @@ class LlamaModel(DecoderModel):
 
         if hasattr(model_config,
                    'lora_config') and model_config.lora_config is not None:
-            self.embed_tokens.weight.value = weight.to(self.embed_tokens.dtype)
+            with torch.no_grad():
+                x = weight.to(self.embed_tokens.dtype)
+                self.embed_tokens.weight.data.copy_(x)
 
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
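This hunk (and the lm_head hunk that follows) replaces assignment through .weight.value with an in-place copy under torch.no_grad(). A minimal stand-alone illustration of that PyTorch pattern, unrelated to TRT-LLM internals:

    import torch
    import torch.nn as nn

    embed = nn.Embedding(8, 4)
    new_weight = torch.randn(8, 4, dtype=torch.float64)  # e.g. loaded from a checkpoint

    with torch.no_grad():
        x = new_weight.to(embed.weight.dtype)  # match the module's dtype
        embed.weight.data.copy_(x)             # in-place copy keeps the same Parameter object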
@@ -377,9 +377,9 @@ class DecoderModelForCausalLM(nn.Module,
 
         if hasattr(config,
                    'lora_config') and config.lora_config is not None:
-            # TODO smor- figure out if it sticks
-            self.lm_head.weight.value = weight.to(
-                self.lm_head.dtype).cuda()
+            with torch.no_grad():
+                x = weight.to(self.lm_head.dtype).cuda()
+                self.lm_head.weight.data.copy_(x)
 
         # use embedding weights in lm_head if tie word embedding is enabled
         if config.pretrained_config.tie_word_embeddings and not isinstance(
@@ -201,11 +201,11 @@ class Attention(nn.Module):
                                        out_scale=out_scale,
                                        attention_mask=attention_mask,
                                        mrope_config=mrope_config)
 
+        hidden_states = attn_output
         attn_output = self.o_proj(attn_output,
                                   all_reduce_params=all_reduce_params)
         if bool(lora_params):
-            attn_lora_output = self.o_lora(attn_output, lora_params,
+            attn_lora_output = self.o_lora(hidden_states, lora_params,
                                            self.layer_idx)
             if attn_lora_output is not None:
                 attn_output = attn_output + attn_lora_output
 
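The attention hunk stashes the attention output in hidden_states before o_proj so the LoRA branch of the output projection sees the projection's input rather than its output. In LoRA terms (scaling omitted), the adapted projection is W·h + B·A·h, so both branches must consume the same h. A minimal sketch of that structure in plain PyTorch, not the TRT-LLM LoraLayer:

    import torch
    import torch.nn as nn

    hidden, rank = 64, 8
    o_proj = nn.Linear(hidden, hidden, bias=False)
    lora_a = nn.Linear(hidden, rank, bias=False)   # "A": hidden -> rank
    lora_b = nn.Linear(rank, hidden, bias=False)   # "B": rank -> hidden

    h = torch.randn(2, hidden)                     # attention output, pre-projection
    out = o_proj(h) + lora_b(lora_a(h))            # both branches consume the same h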
@@ -94,8 +94,8 @@ class LoraLayer(torch.nn.Module):
 
     def forward(self, x, lora_params: Dict,
                 layer_idx: int) -> Optional[torch.Tensor]:
-        if bool(lora_params):
 
+        if bool(lora_params):
             lora_ranks = []
             lora_weight_pointers = []
             active_lora_module_ids = []
@@ -147,7 +147,7 @@ class LoraLayer(torch.nn.Module):
                         ],
                         dtype=x.dtype,
                         device=x.device))
-            # TODO smor should be by: dim=q_lora.rank() - 1
+
             lora_output = torch.cat(lora_output, dim=-1)
             return lora_output
 
@@ -773,8 +773,6 @@ class PeftCacheManager(BaseResourceManager):
                                       buffer_manager=buffer_manager)
 
     def add_request_peft(self, request: LlmRequest):
-        # TODO smor- a helper function to add a request to the peft cache manager.
-        # Cosnider replacing in favor of prepare_resources
         self.impl.add_request_peft(request, True)
 
     def ensure_batch(self,
@@ -595,8 +595,17 @@ class LLM:
         if self.runtime_context is not None:
             return self.runtime_context.tokenizer
 
+        # TODO smor- need to look more on this
+        # what should be chosen as the tokenizer? the adapter or the base model?
+        # what happens if we have multiple adapters?
+        if hasattr(
+                self.args, "backend"
+        ) and self.args.backend == "pytorch" and self.args.lora_config is not None:
+            tokenizer_path = self.args.lora_config.lora_dir[0]
+        else:
+            tokenizer_path = self.args.model
         return ModelLoader.load_hf_tokenizer(
-            self.args.model,
+            tokenizer_path,
             trust_remote_code=self.args.trust_remote_code,
             use_fast=self.args.tokenizer_mode != 'slow')
 
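The change above picks the tokenizer directory from the first LoRA adapter when the PyTorch backend is used with a lora_config (the Chinese adapter ships an extended tokenizer). A hypothetical stand-alone version of that fallback using Hugging Face transformers directly rather than ModelLoader; the helper name and the tokenizer-file check are assumptions, not the committed behavior:

    import os
    from typing import List, Optional

    from transformers import AutoTokenizer


    def pick_tokenizer_dir(base_model_dir: str,
                           lora_dirs: Optional[List[str]]) -> str:
        # Prefer an adapter directory that ships its own tokenizer files.
        for d in lora_dirs or []:
            if os.path.exists(os.path.join(d, "tokenizer_config.json")):
                return d
        return base_model_dir


    tokenizer = AutoTokenizer.from_pretrained(
        pick_tokenizer_dir("/path/to/llama-v2-13b-hf",
                           ["/path/to/chinese-llama-2-lora-13b"]))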
@@ -1,3 +1,4 @@
+import io
 import json
 import re
 import tarfile
@@ -34,6 +35,10 @@ def get_all_nemo_lora_weights(lora_weights):
             else:
                 continue
             m = layer_pattern.match(key)
+            if m is None:
+                raise KeyError(
+                    f"Failed to extract layer index from key {key} using pattern {layer_pattern.pattern}"
+                )
             layer_idx = int(m.group(1))
             layer_weights[layer_idx][inout] = weights
         else:
@@ -189,7 +194,10 @@ class HfLoraLoader:
             with open(f"{lora_dir}/adapter_config.json") as f:
                 adapter_config = json.load(f)
 
-            lora_weight = load_state_dict(get_model_path(lora_dir, "adapter_model"))
+            model_path = get_model_path(lora_dir, "adapter_model")
+            if model_path is None:
+                raise ValueError(f"adapter_model file does not exist in {lora_dir}")
+            lora_weight = load_state_dict(model_path)
             self.lora_weight = lora_weight
             if adapter_config["modules_to_save"] is not None:
                 if "lm_head" in adapter_config["modules_to_save"]:
@@ -288,7 +296,7 @@ def load_torch_hf_lora(lora_config: LoraConfig):
 def load_hf_lora(
     model,
     lora_config: LoraConfig,
-    trtllm_modules_to_hf_modules: Dict[str, str] = None,
+    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
 ):
     trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules or get_default_trtllm_modules_to_hf_modules(
     )
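This and the use_lora hunk below only tighten the type hint; behavior is unchanged because the body already falls back with `or`. The idiom in isolation (the default dict here is a made-up placeholder, not the real module mapping):

    from typing import Dict, Optional


    def demo(mapping: Optional[Dict[str, str]] = None) -> Dict[str, str]:
        # `None` default plus `or` fallback avoids a mutable default argument.
        return mapping or {"example_module": "example_hf_module"}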
@@ -374,7 +382,7 @@ def load_hf_lora(
 def use_lora(
     model,
     lora_config: LoraConfig,
-    trtllm_modules_to_hf_modules: Dict[str, str] = None,
+    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
 ):
     if lora_config.lora_ckpt_source == "nemo":
         load_nemo_lora(model, lora_config)
@@ -388,17 +396,27 @@ def use_lora(
 def unpack_nemo_weights(nemo_archive_path):
     with tarfile.open(nemo_archive_path) as tar:
         try:
-            model_weights = tar.extractfile("model_weights.ckpt")
-            model_config = tar.extractfile("model_config.yaml")
+            model_weights_file = tar.extractfile("model_weights.ckpt")
+            model_config_file = tar.extractfile("model_config.yaml")
         except KeyError:
             try:
-                model_weights = tar.extractfile("./model_weights.ckpt")
-                model_config = tar.extractfile("./model_config.yaml")
+                model_weights_file = tar.extractfile("./model_weights.ckpt")
+                model_config_file = tar.extractfile("./model_config.yaml")
             except KeyError:
                 err_str = "Both model_weights paths not found in the tar archive."
                 raise Exception(err_str)
-    return yaml.safe_load(model_config), torch.load(
-        model_weights, map_location=torch.device("cpu"))
+
+        if model_weights_file is None or model_config_file is None:
+            raise Exception("Could not extract model weights or config files")
+
+        model_config_content = model_config_file.read()
+        model_config_dict = yaml.safe_load(model_config_content)
+
+        model_weights_bytes = model_weights_file.read()
+        model_weights_dict = torch.load(io.BytesIO(model_weights_bytes),
+                                        map_location=torch.device("cpu"))
+
+    return model_config_dict, model_weights_dict
 
 
 class LoraManager(object):
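unpack_nemo_weights now reads both tar members while the archive is still open, checks for None (tarfile.extractfile returns None for non-regular members), and hands torch.load a seekable io.BytesIO instead of the raw extracted stream. The same pattern in isolation; the archive and member names are placeholders for illustration:

    import io
    import tarfile

    import torch
    import yaml

    with tarfile.open("adapter.nemo") as tar:  # placeholder archive name
        weights_member = tar.extractfile("model_weights.ckpt")
        config_member = tar.extractfile("model_config.yaml")
        if weights_member is None or config_member is None:
            raise Exception("Could not extract model weights or config files")
        config = yaml.safe_load(config_member.read())  # read while the tar is open
        weights = torch.load(io.BytesIO(weights_member.read()),
                             map_location=torch.device("cpu"))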
@@ -507,7 +525,7 @@ class LoraManager(object):
 
     def load_from_nemo(self,
                        model_files: List[str],
-                       model_config: 'ModelConfig',
+                       model_config: Union['ModelConfig', LoraModelConfig],
                        runtime_mapping: Optional[Mapping] = None,
                        uids: Optional[List[str]] = None):
         if runtime_mapping is None:
@@ -533,9 +551,11 @@ class LoraManager(object):
 
         def load_from_model_file(uid, model_file):
             if uid not in self._cpp_lora_weights:
-                self._cpp_lora_weights[uid] = []
+                self._cpp_lora_weights[uid] = [
+                ]  # Will be converted to tensor later
             if uid not in self._cpp_lora_config:
-                self._cpp_lora_config[uid] = []
+                self._cpp_lora_config[uid] = [
+                ]  # Will be converted to tensor later
 
             _, nemo_weights = unpack_nemo_weights(model_file)
             all_lora_weights = get_all_nemo_lora_weights(nemo_weights)
@@ -704,12 +724,17 @@ class LoraManager(object):
 
         def load_from_model_dir(uid, model_dir, hf_config):
             if uid not in self._cpp_lora_weights:
-                self._cpp_lora_weights[uid] = []
+                self._cpp_lora_weights[uid] = [
+                ]  # Will be converted to tensor later
             if uid not in self._cpp_lora_config:
-                self._cpp_lora_config[uid] = []
+                self._cpp_lora_config[uid] = [
+                ]  # Will be converted to tensor later
 
             lora_model = load_state_dict(
                 get_model_path(model_dir, "adapter_model"))
+            if lora_model is None:
+                raise ValueError(
+                    f"Failed to load adapter_model from {model_dir}")
             lora_model = preprocess_lora_weights(lora_model)
             all_weights = get_all_hf_lora_weights(lora_model, hf_modules,
                                                   component)
@@ -786,7 +811,7 @@ class LoraManager(object):
                     t_out = torch.split(t_out,
                                         t_out.shape[dim] // tp_size,
                                         dim=dim)[tp_rank].contiguous()
-                    if dim == 0 and is_dora:
+                    if dim == 0 and is_dora and t_mag is not None:
                         t_mag = torch.split(t_mag,
                                             t_mag.shape[0] // tp_size,
                                             dim=0)[tp_rank].contiguous()
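For context, the tensor-parallel split above keeps a contiguous slice of the LoRA "out" matrix per rank along the sharded dimension; the added t_mag is not None guard simply skips the DoRA magnitude vector when the adapter is plain LoRA. The split in isolation (shapes are made up):

    import torch

    t_out = torch.randn(128, 64)   # made-up LoRA "out" weight
    tp_size, tp_rank, dim = 4, 1, 0

    shard = torch.split(t_out, t_out.shape[dim] // tp_size,
                        dim=dim)[tp_rank].contiguous()
    assert shard.shape == (32, 64)  # rank 1 keeps rows 32..63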
@@ -796,7 +821,7 @@ class LoraManager(object):
 
                 t_in = t_in.cuda().contiguous()
                 t_out = t_out.cuda().contiguous()
-                if is_dora:
+                if is_dora and t_mag is not None:
                     t_mag = t_mag.cuda().contiguous()
 
                 if rs_lora:
@@ -807,7 +832,7 @@ class LoraManager(object):
                     t_out = t_out * scale
                 t_in = t_in.to(str_dtype_to_torch(model_config.dtype))
                 t_out = t_out.to(str_dtype_to_torch(model_config.dtype))
-                if is_dora:
+                if is_dora and t_mag is not None:
                     t_mag = t_mag.to(str_dtype_to_torch(model_config.dtype))
 
                 self._lora_uid_to_low_ranks[uid][layer_idx][
@@ -816,20 +841,26 @@ class LoraManager(object):
                     lora_module] = [
                         t_in.data_ptr(),
                         t_out.data_ptr(),
-                        t_mag.data_ptr() if is_dora else 0
+                        t_mag.data_ptr() if
+                        (is_dora and t_mag is not None) else 0
                     ]
 
                 # prevent torch free this buffer
                 self._lora_weights.append(t_in)
                 self._lora_weights.append(t_out)
-                if is_dora:
+                if is_dora and t_mag is not None:
                     self._lora_weights.append(t_mag)
 
+                t_in_cpu = t_in.flatten().cpu()
+                t_out_cpu = t_out.flatten().cpu()
+                weights_to_concat = [t_in_cpu, t_out_cpu]
+
+                if is_dora and t_mag is not None:
+                    t_mag_cpu = t_mag.flatten().cpu()
+                    weights_to_concat.append(t_mag_cpu)
+
                 self._cpp_lora_weights[uid].append(
-                    torch.concatenate(
-                        [t_in.flatten().cpu(),
-                         t_out.flatten().cpu()] +
-                        ([t_mag.flatten().cpu()] if is_dora else [])))
+                    torch.cat(weights_to_concat))
                 self._cpp_lora_config[uid].append(
                     torch.tensor([
                         self.LORA_MODULE_IDS[lora_module], layer_idx,
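The rewritten block builds weights_to_concat explicitly and concatenates with torch.cat, which keeps the optional DoRA magnitude handling in one place. The packing itself is just flatten-and-concatenate into one 1-D CPU buffer, shown here with illustrative shapes:

    import torch

    t_in = torch.randn(8, 16)   # "A" weights: rank x hidden (made up)
    t_out = torch.randn(32, 8)  # "B" weights: out x rank (made up)

    weights_to_concat = [t_in.flatten().cpu(), t_out.flatten().cpu()]
    flat = torch.cat(weights_to_concat)
    assert flat.numel() == t_in.numel() + t_out.numel()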
@@ -8,7 +8,7 @@ from parameterized import parameterized
 from transformers import LlamaConfig
 from transformers import LlamaForCausalLM as HFLlamaForCausalLM
 from utils.llm_data import llm_models_root
-from utils.util import getSMVersion, skip_gpu_memory_less_than
+from utils.util import getSMVersion, similar, skip_gpu_memory_less_than
 
 import tensorrt_llm
 from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@@ -369,13 +369,8 @@ class TestLlama(unittest.TestCase):
 
         kv_cache_manager.shutdown()
 
-    @skip_gpu_memory_less_than(40 * 1024 * 1024 *
-                               1024)  # 40gb, same as test_llmapi
+    @skip_gpu_memory_less_than(40 * 2**30)  # 40GB memory
     def test_llama_lora(self) -> None:
-        # TODO smor- this test is running but correctness is not guaranteed.
-        # The following PR will ensure correctness
-        # We might want to change this test location elsewhere
-
         lora_config = LoraConfig(lora_dir=[
             f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b"
         ],
@@ -402,10 +397,4 @@ class TestLlama(unittest.TestCase):
                                sampling_params,
                                lora_request=lora_request)
 
-        # Print the outputs.
-        for output in outputs:
-            prompt = output.prompt
-            generated_text = output.outputs[0].text
-            print(
-                f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\nExpected: {references[0]!r}"
-            )
+        assert similar(outputs[0].outputs[0].text, references[0])
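The test now asserts against the reference with the imported similar helper instead of only printing. A stand-in with the same shape, assuming it is a fuzzy string comparison (the real implementation lives in the tests' utils.util module):

    from difflib import SequenceMatcher


    def similar(a: str, b: str, threshold: float = 0.8) -> bool:
        # Ratio of matching characters between generated and reference text.
        return SequenceMatcher(None, a, b).ratio() >= threshold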