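"""Quickstart example: run multimodal models (image / video / audio inputs) with the
TensorRT-LLM PyTorch workflow, reusing the LLM and sampling setup from
quickstart_advanced.py."""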
import argparse
import json
import os
import time

from quickstart_advanced import add_llm_args, setup_llm

from tensorrt_llm.inputs import default_multimodal_input_loader
from tensorrt_llm.inputs.registry import MULTIMODAL_PLACEHOLDER_REGISTRY
from tensorrt_llm.tools.importlib_utils import import_custom_module_from_dir

example_medias_and_prompts = {
    "image": {
        "media": [
            "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png",
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
            "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
        ],
        "prompt": [
            "Describe the natural environment in the image.",
            "Describe the object and the weather condition in the image.",
            "Describe the traffic condition on the road in the image.",
        ]
    },
    "video": {
        "media": [
            "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4",
            "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4",
        ],
        "prompt": [
            "Tell me what you see in the video briefly.",
            "Describe the scene in the video briefly.",
        ]
    },
    "audio": {
        "media": [
            "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_the_traffic_sign_in_the_image.wav",
            "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav",
        ],
        "prompt": [
            "Transcribe the audio clip into text, please don't add other text.",
            "Transcribe the audio clip into text, please don't add other text.",
        ]
    },
    "image_audio": {
        "media": [
            [
                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
                "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav"
            ],
            [
                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
                "https://huggingface.co/microsoft/Phi-4-multimodal-instruct/resolve/main/examples/what_is_shown_in_this_image.wav"
            ],
        ],
        "prompt": [
            "Describe the scene in the image briefly.",
            "",
        ]
    },
    "multiple_image": {
        "media": [
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png",
            "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg",
        ],
        "prompt": ["Describe the difference between the two images."],
    },
    "mixture_text_image": {
        "media": [
            [],
            [
                "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"
            ],
        ],
        "prompt": [
            "Who invented the internet?",
            "Describe the scene in the image briefly.",
        ],
    },
}
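
# If --prompt / --media are not provided on the command line, the entries above are
# used as the defaults for the selected --modality (see the single-turn path in main()).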


def add_multimodal_args(parser):
    parser.add_argument(
        "--model_type",
        type=str,
        choices=MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(),
        help="Model type as specified in the HuggingFace model config.")
    parser.add_argument("--modality",
                        type=str,
                        choices=[
                            "image", "video", "audio", "image_audio",
                            "multiple_image", "mixture_text_image"
                        ],
                        default="image",
                        help="Media type being used for inference.")
    parser.add_argument("--media",
                        type=str,
                        nargs="+",
                        help="A single media filepath/URL or a list of them.")
    parser.add_argument("--num_frames",
                        type=int,
                        default=8,
                        help="The number of video frames to be sampled.")
    parser.add_argument("--image_format",
                        type=str,
                        choices=["pt", "pil"],
                        default="pt",
                        help="The format of the image.")
    parser.add_argument("--device",
                        type=str,
                        default="cpu",
                        help="The device to have the input on.")
    parser.add_argument(
        "--custom_module_dirs",
        type=str,
        nargs="+",
        default=None,
        help=("Paths to out-of-tree model directories that should be imported."
              " This is useful for loading custom models. Each directory should have a structure like:"
              " <model_name>"
              " ├── __init__.py"
              " ├── <model_name>.py"
              " └── <sub_dirs>"))
    # Multi-turn conversation related parameters.
    parser.add_argument("--multiturn",
                        action="store_true",
                        help="Enable multi-turn conversation mode.")
    parser.add_argument(
        "--conversation_turns",
        type=int,
        default=2,
        help="Number of conversation turns for automated testing.")
    return parser


def add_lora_args(parser):
    parser.add_argument("--load_lora",
                        default=False,
                        action='store_true',
                        help="Whether to load the LoRA model.")
    parser.add_argument("--auto_model_name",
                        type=str,
                        default=None,
                        help="The auto model name in the TRT-LLM repo.")
    return parser


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Multimodal models with the PyTorch workflow.")
    parser = add_llm_args(parser)
    parser.add_argument("--kv_cache_fraction", type=float, default=0.6)
    parser = add_multimodal_args(parser)
    parser = add_lora_args(parser)
    args = parser.parse_args()
    return args
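
# Example invocation (a sketch: the script filename and the flags contributed by
# add_llm_args, such as --model_dir and --prompt, are defined in quickstart_advanced.py
# and are assumed here rather than confirmed):
#   python quickstart_multimodal.py \
#       --model_dir <path_or_hf_id_of_multimodal_model> \
#       --modality image \
#       --prompt "Describe the natural environment in the image." \
#       --media https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/seashore.png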


def main():
    args = parse_arguments()

    if args.custom_module_dirs is not None:
        for custom_module_dir in args.custom_module_dirs:
            try:
                import_custom_module_from_dir(custom_module_dir)
            except Exception as e:
                print(
                    f"Failed to import custom module from {custom_module_dir}: {e}"
                )
                raise

    lora_config = None
    model_class = None
    if args.load_lora:
        assert args.auto_model_name is not None, "Please provide the auto model name to load the LoRA config."
        import importlib
        models_module = importlib.import_module('tensorrt_llm._torch.models')
        model_class = getattr(models_module, args.auto_model_name)
        lora_config = model_class.lora_config(args.model_dir)
        # For stability, explicitly size the LoRA GPU and CPU caches to hold 2 adapters.
        lora_config.max_loras = 2
        lora_config.max_cpu_loras = 2

    llm, sampling_params = setup_llm(args, lora_config=lora_config)

    image_format = args.image_format
    if args.model_type is not None:
        model_type = args.model_type
    else:
        with open(os.path.join(str(llm._hf_model_dir), 'config.json')) as f:
            model_type = json.load(f)['model_type']
    assert model_type in MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types(), \
        f"Unsupported model_type: {model_type} found!\n" \
        f"Supported types: {MULTIMODAL_PLACEHOLDER_REGISTRY.get_registered_model_types()}"

    # Multi-turn mode: run a scripted multi-turn conversation over the provided prompts.
    if args.multiturn:
        assert args.prompt is not None, "Please provide prompts for the multi-turn conversation."
        assert args.media is not None, "Please provide media for the multi-turn conversation."
        # Determine how many turns to run.
        max_turns = min(args.conversation_turns, len(args.prompt))
        generated_outputs = []  # Store generated outputs for the final report.

        # Initialize the conversation history with the first prompt.
        conversation_history = args.prompt[0] if args.prompt else ""

        for i in range(max_turns):
            print(f"\n--- Turn {i + 1} ---")

            try:
                # Feed the accumulated conversation history (not just the current
                # prompt) through the multimodal input loader.
                cur_prompt = conversation_history
                inputs = default_multimodal_input_loader(
                    tokenizer=llm.tokenizer,
                    model_dir=str(llm._hf_model_dir),
                    model_type=model_type,
                    modality=args.modality,
                    prompts=[cur_prompt],
                    media=args.media,
                    image_data_format=image_format,
                    num_frames=args.num_frames,
                    device=args.device)

                lora_request = None
                if args.load_lora:
                    if model_class is None:
                        raise ValueError(
                            "model_class must be provided when load_lora is True")
                    lora_request = model_class.lora_request(
                        len(inputs), args.modality, llm._hf_model_dir)

                # Generate the response for this turn.
                outputs = llm.generate(inputs,
                                       sampling_params,
                                       lora_request=lora_request)
                assert outputs and len(outputs) > 0 and outputs[0].outputs and len(
                    outputs[0].outputs) > 0
                response = outputs[0].outputs[0].text.strip()

                # Store the generated output for this turn.
                generated_outputs.append({
                    "turn": i + 1,
                    "user_input": cur_prompt,
                    "assistant_response": response,
                    "media": args.media
                })

                # Append the response (and the next user prompt, if any) to the history.
                conversation_history = conversation_history + "\n" + response
                if i + 1 < len(args.prompt):
                    conversation_history = conversation_history + "\n" + args.prompt[i + 1]

            except Exception as e:
                print(f"Error in turn {i + 1}: {e}")
                import traceback
                traceback.print_exc()
                continue

        for i, output in enumerate(generated_outputs):
            print(
                f"[{i}] Prompt: {output['user_input']!r}, Generated text: {output['assistant_response']!r}"
            )

        if args.log_kv_cache_events:
            time.sleep(1)  # Wait for events to be dispatched.
            events = llm.get_kv_cache_events(5)
            print("=== KV_CACHE_EVENTS_START ===")
            print(json.dumps(events, indent=2))
            print("=== KV_CACHE_EVENTS_END ===")

        return

    # Original single-turn processing logic.
    # Fall back to the example prompts and media when they are not provided.
    if args.prompt is None:
        args.prompt = example_medias_and_prompts[args.modality]["prompt"]
    if args.media is None:
        args.media = example_medias_and_prompts[args.modality]["media"]

    inputs = default_multimodal_input_loader(tokenizer=llm.tokenizer,
                                             model_dir=str(llm._hf_model_dir),
                                             model_type=model_type,
                                             modality=args.modality,
                                             prompts=args.prompt,
                                             media=args.media,
                                             image_data_format=image_format,
                                             num_frames=args.num_frames,
                                             device=args.device)
    lora_request = None
    if args.load_lora:
        lora_request = model_class.lora_request(len(inputs), args.modality,
                                                llm._hf_model_dir)

    outputs = llm.generate(
        inputs,
        sampling_params,
        lora_request=lora_request,
    )

    for i, output in enumerate(outputs):
        prompt = args.prompt[i]
        generated_text = output.outputs[0].text
        print(f"[{i}] Prompt: {prompt!r}, Generated text: {generated_text!r}")
        if args.return_context_logits:
            print(f"[{i}] Context logits: {output.context_logits}")
        if args.return_generation_logits:
            print(
                f"[{i}] Generation logits: {output.outputs[0].generation_logits}"
            )
        if args.logprobs:
            print(f"[{i}] Logprobs: {output.outputs[0].logprobs}")

    if args.log_kv_cache_events:
        time.sleep(1)  # Wait for events to be dispatched.
        events = llm.get_kv_cache_events(5)
        print("=== KV_CACHE_EVENTS_START ===")
        print(json.dumps(events, indent=2))
        print("=== KV_CACHE_EVENTS_END ===")
if __name__ == "__main__":
|
|
main()
|