mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
[https://nvbugs/5503529][fix] Change test_llmapi_example_multilora to get adapters path from cmd line to avoid downloading from HF (#7740)
Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
This commit is contained in:
parent
6eef19297f
commit
750d15bfaa
@ -1,6 +1,10 @@
|
|||||||
### :section Customization
|
### :section Customization
|
||||||
### :title Generate text with multiple LoRA adapters
|
### :title Generate text with multiple LoRA adapters
|
||||||
### :order 5
|
### :order 5
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
|
||||||
from tensorrt_llm import LLM
|
from tensorrt_llm import LLM
|
||||||
@ -8,17 +12,24 @@ from tensorrt_llm.executor import LoRARequest
|
|||||||
from tensorrt_llm.lora_helper import LoraConfig
|
from tensorrt_llm.lora_helper import LoraConfig
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main(chatbot_lora_dir: Optional[str], mental_health_lora_dir: Optional[str],
|
||||||
|
tarot_lora_dir: Optional[str]):
|
||||||
|
|
||||||
# Download the LoRA adapters from huggingface hub.
|
# Download the LoRA adapters from huggingface hub, if not provided via command line args.
|
||||||
lora_dir1 = snapshot_download(repo_id="snshrivas10/sft-tiny-chatbot")
|
if chatbot_lora_dir is None:
|
||||||
lora_dir2 = snapshot_download(
|
chatbot_lora_dir = snapshot_download(
|
||||||
repo_id="givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational")
|
repo_id="snshrivas10/sft-tiny-chatbot")
|
||||||
lora_dir3 = snapshot_download(repo_id="barissglc/tinyllama-tarot-v1")
|
if mental_health_lora_dir is None:
|
||||||
|
mental_health_lora_dir = snapshot_download(
|
||||||
|
repo_id=
|
||||||
|
"givyboy/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational")
|
||||||
|
if tarot_lora_dir is None:
|
||||||
|
tarot_lora_dir = snapshot_download(
|
||||||
|
repo_id="barissglc/tinyllama-tarot-v1")
|
||||||
|
|
||||||
# Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config.
|
# Currently, we need to pass at least one lora_dir to LLM constructor via build_config.lora_config.
|
||||||
# This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support.
|
# This is necessary because it requires some configuration in the lora_dir to build the engine with LoRA support.
|
||||||
lora_config = LoraConfig(lora_dir=[lora_dir1],
|
lora_config = LoraConfig(lora_dir=[chatbot_lora_dir],
|
||||||
max_lora_rank=64,
|
max_lora_rank=64,
|
||||||
max_loras=3,
|
max_loras=3,
|
||||||
max_cpu_loras=3)
|
max_cpu_loras=3)
|
||||||
@ -39,10 +50,11 @@ def main():
|
|||||||
for output in llm.generate(prompts,
|
for output in llm.generate(prompts,
|
||||||
lora_request=[
|
lora_request=[
|
||||||
None,
|
None,
|
||||||
LoRARequest("chatbot", 1, lora_dir1), None,
|
LoRARequest("chatbot", 1, chatbot_lora_dir),
|
||||||
LoRARequest("mental-health", 2, lora_dir2),
|
|
||||||
None,
|
None,
|
||||||
LoRARequest("tarot", 3, lora_dir3)
|
LoRARequest("mental-health", 2,
|
||||||
|
mental_health_lora_dir), None,
|
||||||
|
LoRARequest("tarot", 3, tarot_lora_dir)
|
||||||
]):
|
]):
|
||||||
prompt = output.prompt
|
prompt = output.prompt
|
||||||
generated_text = output.outputs[0].text
|
generated_text = output.outputs[0].text
|
||||||
@ -58,4 +70,20 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
    # Each adapter path is optional; when omitted (None), main() falls back to
    # downloading that adapter from the Hugging Face hub.
    parser = argparse.ArgumentParser(
        description="Generate text with multiple LoRA adapters")
    parser.add_argument('--chatbot_lora_dir',
                        type=str,
                        default=None,
                        help='Path to the chatbot LoRA directory')
    parser.add_argument('--mental_health_lora_dir',
                        type=str,
                        default=None,
                        help='Path to the mental health LoRA directory')
    parser.add_argument('--tarot_lora_dir',
                        type=str,
                        default=None,
                        help='Path to the tarot LoRA directory')
    args = parser.parse_args()
    main(args.chatbot_lora_dir, args.mental_health_lora_dir,
         args.tarot_lora_dir)
|
|||||||
@ -110,7 +110,16 @@ def test_llmapi_example_inference_async_streaming(llm_root, engine_dir,
|
|||||||
|
|
||||||
|
|
||||||
def test_llmapi_example_multilora(llm_root, engine_dir, llm_venv):
    """Run the llm_multilora.py example against locally cached LoRA adapters.

    Passing the adapter directories on the command line keeps the example from
    downloading them from the Hugging Face hub (nvbugs/5503529).
    """
    # Hoisted: look up the models root once instead of once per adapter path.
    models_root = llm_models_root()
    cmd_line_args = [
        "--chatbot_lora_dir",
        f"{models_root}/llama-models-v2/sft-tiny-chatbot",
        "--mental_health_lora_dir",
        f"{models_root}/llama-models-v2/TinyLlama-1.1B-Chat-v1.0-mental-health-conversational",
        "--tarot_lora_dir",
        f"{models_root}/llama-models-v2/tinyllama-tarot-v1",
    ]
    _run_llmapi_example(llm_root, engine_dir, llm_venv, "llm_multilora.py",
                        *cmd_line_args)
|
|
||||||
|
|
||||||
def test_llmapi_example_guided_decoding(llm_root, engine_dir, llm_venv):
|
def test_llmapi_example_guided_decoding(llm_root, engine_dir, llm_venv):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user