add passing E2E LoRA flow (#3788)

Signed-off-by: Shahar Mor <smor@nvidia.com>
shaharmor98 2025-04-23 18:38:06 +03:00 committed by GitHub
parent a51b3cf7a6
commit 49262a62a5
9 changed files with 115 additions and 48 deletions

View File

@ -0,0 +1,38 @@
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig


def main():
    lora_config = LoraConfig(lora_dir=[
        "/home/scratch.trt_llm_data/llm-models/llama-models-v2/chinese-llama-2-lora-13b"
    ],
                             max_lora_rank=64)
    llm = LLM(
        model=
        "/home/scratch.trt_llm_data/llm-models/llama-models-v2/llama-v2-13b-hf",
        lora_config=lora_config,
    )

    prompts = [
        "今天天气很好,我到公园的时候,",
    ]
    sampling_params = SamplingParams(max_tokens=20, add_special_tokens=False)

    lora_req_2 = LoRARequest(
        "task-0", 0,
        "/home/scratch.trt_llm_data/llm-models/llama-models-v2/chinese-llama-2-lora-13b"
    )
    lora_request = [lora_req_2]

    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")


if __name__ == '__main__':
    main()

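The new example drives a single adapter. A minimal sketch of the same flow with two prompts bound to two different adapters, assuming lora_request accepts one LoRARequest per prompt; the adapter and model paths below are placeholders, not real checkpoints:

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.executor import LoRARequest
from tensorrt_llm.lora_manager import LoraConfig

# Placeholder paths; substitute real checkpoints.
BASE_MODEL = "/path/to/base-model"
ADAPTER_A = "/path/to/adapter-a"
ADAPTER_B = "/path/to/adapter-b"

llm = LLM(model=BASE_MODEL,
          lora_config=LoraConfig(lora_dir=[ADAPTER_A, ADAPTER_B],
                                 max_lora_rank=64))
prompts = ["First prompt", "Second prompt"]
# One LoRARequest per prompt: (name, adapter id, adapter path).
lora_requests = [
    LoRARequest("task-a", 0, ADAPTER_A),
    LoRARequest("task-b", 1, ADAPTER_B),
]
outputs = llm.generate(prompts,
                       SamplingParams(max_tokens=20),
                       lora_request=lora_requests)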
View File

@ -757,7 +757,9 @@ class LlamaModel(DecoderModel):
if hasattr(model_config,
'lora_config') and model_config.lora_config is not None:
self.embed_tokens.weight.value = weight.to(self.embed_tokens.dtype)
with torch.no_grad():
x = weight.to(self.embed_tokens.dtype)
self.embed_tokens.weight.data.copy_(x)
self.layers = nn.ModuleList([
LlamaDecoderLayer(

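The change above replaces direct assignment to the parameter's .value with an in-place copy_ under torch.no_grad(). A standalone sketch of that pattern on a plain nn.Embedding, with illustrative shapes rather than the model's actual wiring:

import torch
import torch.nn as nn

embed_tokens = nn.Embedding(num_embeddings=32000, embedding_dim=512)
weight = torch.randn(32000, 512, dtype=torch.float16)  # e.g. a loaded checkpoint weight

# Copy in place without tracking gradients, casting to the module's dtype;
# the Parameter object itself (and anything holding a reference to it) is unchanged.
with torch.no_grad():
    x = weight.to(embed_tokens.weight.dtype)
    embed_tokens.weight.data.copy_(x)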
View File

@ -377,9 +377,9 @@ class DecoderModelForCausalLM(nn.Module,
if hasattr(config,
'lora_config') and config.lora_config is not None:
# TODO smor- figure out if it sticks
self.lm_head.weight.value = weight.to(
self.lm_head.dtype).cuda()
with torch.no_grad():
x = weight.to(self.lm_head.dtype).cuda()
self.lm_head.weight.data.copy_(x)
# use embedding weights in lm_head if tie word embedding is enabled
if config.pretrained_config.tie_word_embeddings and not isinstance(

View File

@ -201,11 +201,11 @@ class Attention(nn.Module):
out_scale=out_scale,
attention_mask=attention_mask,
mrope_config=mrope_config)
hidden_states = attn_output
attn_output = self.o_proj(attn_output,
all_reduce_params=all_reduce_params)
if bool(lora_params):
attn_lora_output = self.o_lora(attn_output, lora_params,
attn_lora_output = self.o_lora(hidden_states, lora_params,
self.layer_idx)
if attn_lora_output is not None:
attn_output = attn_output + attn_lora_output

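The fix routes the LoRA branch of the output projection through the pre-projection hidden states instead of the already-projected output, i.e. y = W_o x + B(A x). A self-contained sketch with plain nn.Linear stand-ins for o_proj and o_lora:

import torch
import torch.nn as nn

hidden_size, rank = 64, 8
o_proj = nn.Linear(hidden_size, hidden_size, bias=False)   # base W_o
lora_a = nn.Linear(hidden_size, rank, bias=False)          # LoRA A
lora_b = nn.Linear(rank, hidden_size, bias=False)          # LoRA B

x = torch.randn(2, hidden_size)              # attention output before o_proj
hidden_states = x                            # keep the pre-projection activations
attn_output = o_proj(x)                      # W_o x
lora_delta = lora_b(lora_a(hidden_states))   # B (A x), not B (A (W_o x))
attn_output = attn_output + lora_delta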
View File

@ -94,8 +94,8 @@ class LoraLayer(torch.nn.Module):
def forward(self, x, lora_params: Dict,
layer_idx: int) -> Optional[torch.Tensor]:
if bool(lora_params):
if bool(lora_params):
lora_ranks = []
lora_weight_pointers = []
active_lora_module_ids = []
@ -147,7 +147,7 @@ class LoraLayer(torch.nn.Module):
],
dtype=x.dtype,
device=x.device))
# TODO smor- should be dim=q_lora.rank() - 1
lora_output = torch.cat(lora_output, dim=-1)
return lora_output

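When one LoraLayer serves several fused modules (e.g. q/k/v), each module's low-rank output is computed separately and the pieces are concatenated on the last dimension to recover the fused projection layout. A tiny illustration with made-up shapes:

import torch

num_tokens = 4
# Hypothetical per-module LoRA outputs for a fused q/k/v projection.
q_out = torch.zeros(num_tokens, 128)
k_out = torch.zeros(num_tokens, 64)
v_out = torch.zeros(num_tokens, 64)

lora_output = torch.cat([q_out, k_out, v_out], dim=-1)
assert lora_output.shape == (num_tokens, 256)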
View File

@ -773,8 +773,6 @@ class PeftCacheManager(BaseResourceManager):
buffer_manager=buffer_manager)
def add_request_peft(self, request: LlmRequest):
# TODO smor- a helper function to add a request to the peft cache manager.
# Consider replacing it in favor of prepare_resources
self.impl.add_request_peft(request, True)
def ensure_batch(self,

View File

@ -595,8 +595,17 @@ class LLM:
if self.runtime_context is not None:
return self.runtime_context.tokenizer
# TODO smor- need to look into this further:
# which tokenizer should be chosen, the adapter's or the base model's?
# And what happens if we have multiple adapters?
if hasattr(
self.args, "backend"
) and self.args.backend == "pytorch" and self.args.lora_config is not None:
tokenizer_path = self.args.lora_config.lora_dir[0]
else:
tokenizer_path = self.args.model
return ModelLoader.load_hf_tokenizer(
self.args.model,
tokenizer_path,
trust_remote_code=self.args.trust_remote_code,
use_fast=self.args.tokenizer_mode != 'slow')

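With this change the tokenizer comes from the first LoRA adapter directory when the PyTorch backend is used with a lora_config, and from the base model otherwise. The same selection, sketched standalone with transformers; the helper and paths below are illustrative, not the repository's ModelLoader:

from typing import List, Optional

from transformers import AutoTokenizer


def pick_tokenizer_path(model: str, lora_dirs: Optional[List[str]]) -> str:
    # Prefer the adapter's tokenizer (it may extend the base vocabulary,
    # e.g. a Chinese LLaMA adapter); fall back to the base model otherwise.
    return lora_dirs[0] if lora_dirs else model


tokenizer = AutoTokenizer.from_pretrained(
    pick_tokenizer_path("/path/to/base-model", ["/path/to/adapter"]))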
View File

@ -1,3 +1,4 @@
import io
import json
import re
import tarfile
@ -34,6 +35,10 @@ def get_all_nemo_lora_weights(lora_weights):
else:
continue
m = layer_pattern.match(key)
if m is None:
raise KeyError(
f"Failed to extract layer index from key {key} using pattern {layer_pattern.pattern}"
)
layer_idx = int(m.group(1))
layer_weights[layer_idx][inout] = weights
else:
@ -189,7 +194,10 @@ class HfLoraLoader:
with open(f"{lora_dir}/adapter_config.json") as f:
adapter_config = json.load(f)
lora_weight = load_state_dict(get_model_path(lora_dir, "adapter_model"))
model_path = get_model_path(lora_dir, "adapter_model")
if model_path is None:
raise ValueError(f"adapter_model file does not exist in {lora_dir}")
lora_weight = load_state_dict(model_path)
self.lora_weight = lora_weight
if adapter_config["modules_to_save"] is not None:
if "lm_head" in adapter_config["modules_to_save"]:
@ -288,7 +296,7 @@ def load_torch_hf_lora(lora_config: LoraConfig):
def load_hf_lora(
model,
lora_config: LoraConfig,
trtllm_modules_to_hf_modules: Dict[str, str] = None,
trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules or get_default_trtllm_modules_to_hf_modules(
)
@ -374,7 +382,7 @@ def load_hf_lora(
def use_lora(
model,
lora_config: LoraConfig,
trtllm_modules_to_hf_modules: Dict[str, str] = None,
trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
if lora_config.lora_ckpt_source == "nemo":
load_nemo_lora(model, lora_config)
@ -388,17 +396,27 @@ def use_lora(
def unpack_nemo_weights(nemo_archive_path):
with tarfile.open(nemo_archive_path) as tar:
try:
model_weights = tar.extractfile("model_weights.ckpt")
model_config = tar.extractfile("model_config.yaml")
model_weights_file = tar.extractfile("model_weights.ckpt")
model_config_file = tar.extractfile("model_config.yaml")
except KeyError:
try:
model_weights = tar.extractfile("./model_weights.ckpt")
model_config = tar.extractfile("./model_config.yaml")
model_weights_file = tar.extractfile("./model_weights.ckpt")
model_config_file = tar.extractfile("./model_config.yaml")
except KeyError:
err_str = "Neither model_weights.ckpt nor ./model_weights.ckpt was found in the tar archive."
raise Exception(err_str)
return yaml.safe_load(model_config), torch.load(
model_weights, map_location=torch.device("cpu"))
if model_weights_file is None or model_config_file is None:
raise Exception("Could not extract model weights or config files")
model_config_content = model_config_file.read()
model_config_dict = yaml.safe_load(model_config_content)
model_weights_bytes = model_weights_file.read()
model_weights_dict = torch.load(io.BytesIO(model_weights_bytes),
map_location=torch.device("cpu"))
return model_config_dict, model_weights_dict
class LoraManager(object):
@ -507,7 +525,7 @@ class LoraManager(object):
def load_from_nemo(self,
model_files: List[str],
model_config: 'ModelConfig',
model_config: Union['ModelConfig', LoraModelConfig],
runtime_mapping: Optional[Mapping] = None,
uids: Optional[List[str]] = None):
if runtime_mapping is None:
@ -533,9 +551,11 @@ class LoraManager(object):
def load_from_model_file(uid, model_file):
if uid not in self._cpp_lora_weights:
self._cpp_lora_weights[uid] = []
self._cpp_lora_weights[uid] = [
] # Will be converted to tensor later
if uid not in self._cpp_lora_config:
self._cpp_lora_config[uid] = []
self._cpp_lora_config[uid] = [
] # Will be converted to tensor later
_, nemo_weights = unpack_nemo_weights(model_file)
all_lora_weights = get_all_nemo_lora_weights(nemo_weights)
@ -704,12 +724,17 @@ class LoraManager(object):
def load_from_model_dir(uid, model_dir, hf_config):
if uid not in self._cpp_lora_weights:
self._cpp_lora_weights[uid] = []
self._cpp_lora_weights[uid] = [
] # Will be converted to tensor later
if uid not in self._cpp_lora_config:
self._cpp_lora_config[uid] = []
self._cpp_lora_config[uid] = [
] # Will be converted to tensor later
lora_model = load_state_dict(
get_model_path(model_dir, "adapter_model"))
if lora_model is None:
raise ValueError(
f"Failed to load adapter_model from {model_dir}")
lora_model = preprocess_lora_weights(lora_model)
all_weights = get_all_hf_lora_weights(lora_model, hf_modules,
component)
@ -786,7 +811,7 @@ class LoraManager(object):
t_out = torch.split(t_out,
t_out.shape[dim] // tp_size,
dim=dim)[tp_rank].contiguous()
if dim == 0 and is_dora:
if dim == 0 and is_dora and t_mag is not None:
t_mag = torch.split(t_mag,
t_mag.shape[0] // tp_size,
dim=0)[tp_rank].contiguous()
@ -796,7 +821,7 @@ class LoraManager(object):
t_in = t_in.cuda().contiguous()
t_out = t_out.cuda().contiguous()
if is_dora:
if is_dora and t_mag is not None:
t_mag = t_mag.cuda().contiguous()
if rs_lora:
@ -807,7 +832,7 @@ class LoraManager(object):
t_out = t_out * scale
t_in = t_in.to(str_dtype_to_torch(model_config.dtype))
t_out = t_out.to(str_dtype_to_torch(model_config.dtype))
if is_dora:
if is_dora and t_mag is not None:
t_mag = t_mag.to(str_dtype_to_torch(model_config.dtype))
self._lora_uid_to_low_ranks[uid][layer_idx][
@ -816,20 +841,26 @@ class LoraManager(object):
lora_module] = [
t_in.data_ptr(),
t_out.data_ptr(),
t_mag.data_ptr() if is_dora else 0
t_mag.data_ptr() if
(is_dora and t_mag is not None) else 0
]
# prevent torch free this buffer
self._lora_weights.append(t_in)
self._lora_weights.append(t_out)
if is_dora:
if is_dora and t_mag is not None:
self._lora_weights.append(t_mag)
t_in_cpu = t_in.flatten().cpu()
t_out_cpu = t_out.flatten().cpu()
weights_to_concat = [t_in_cpu, t_out_cpu]
if is_dora and t_mag is not None:
t_mag_cpu = t_mag.flatten().cpu()
weights_to_concat.append(t_mag_cpu)
self._cpp_lora_weights[uid].append(
torch.concatenate(
[t_in.flatten().cpu(),
t_out.flatten().cpu()] +
([t_mag.flatten().cpu()] if is_dora else [])))
torch.cat(weights_to_concat))
self._cpp_lora_config[uid].append(
torch.tensor([
self.LORA_MODULE_IDS[lora_module], layer_idx,

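For reference, the packing convention the loader follows above: per-module LoRA weights are flattened and concatenated (in-projection, out-projection, then the optional DoRA magnitude) into one contiguous CPU tensor per module. A small sketch with illustrative shapes:

import torch

rank, hidden = 8, 64
t_in = torch.randn(rank, hidden)    # LoRA A (illustrative shape)
t_out = torch.randn(hidden, rank)   # LoRA B (illustrative shape)
t_mag = None                        # DoRA magnitude vector, when present

weights_to_concat = [t_in.flatten().cpu(), t_out.flatten().cpu()]
if t_mag is not None:
    weights_to_concat.append(t_mag.flatten().cpu())

packed = torch.cat(weights_to_concat)   # one flat buffer per module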
View File

@ -8,7 +8,7 @@ from parameterized import parameterized
from transformers import LlamaConfig
from transformers import LlamaForCausalLM as HFLlamaForCausalLM
from utils.llm_data import llm_models_root
from utils.util import getSMVersion, skip_gpu_memory_less_than
from utils.util import getSMVersion, similar, skip_gpu_memory_less_than
import tensorrt_llm
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
@ -369,13 +369,8 @@ class TestLlama(unittest.TestCase):
kv_cache_manager.shutdown()
@skip_gpu_memory_less_than(40 * 1024 * 1024 *
1024) # 40gb, same as test_llmapi
@skip_gpu_memory_less_than(40 * 2**30) # 40GB memory
def test_llama_lora(self) -> None:
# TODO smor- this test runs, but correctness is not guaranteed.
# The following PR will ensure correctness.
# We might want to move this test elsewhere.
lora_config = LoraConfig(lora_dir=[
f"{llm_models_root()}/llama-models-v2/chinese-llama-2-lora-13b"
],
@ -402,10 +397,4 @@ class TestLlama(unittest.TestCase):
sampling_params,
lora_request=lora_request)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(
f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}\nExpected: {references[0]!r}"
)
assert similar(outputs[0].outputs[0].text, references[0])
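The assertion relies on a similar helper from the test utilities; a plausible stand-in based on difflib is sketched below (an assumption about its behavior, not necessarily the repository's exact implementation):

from difflib import SequenceMatcher


def similar(a: str, b: str, threshold: float = 0.8) -> bool:
    # Fuzzy string comparison: ratio() is in [0, 1], with 1.0 meaning identical.
    return SequenceMatcher(None, a, b).ratio() >= threshold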