import argparse
from typing import Any

import modelopt.torch.opt as mto
import torch
from diffusers import DiffusionPipeline

from tensorrt_llm._torch.auto_deploy.compile import CompileBackendRegistry
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations.library.fusion import fuse_gemms
from tensorrt_llm._torch.auto_deploy.transformations.library.quantization import quantize
from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
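
# Allow torch._dynamo to cache more compiled graph variants than its default limit.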
torch._dynamo.config.cache_size_limit = 100


def generate_image(pipe: DiffusionPipeline, prompt: str, image_name: str) -> None:
    """Generate an image using the given pipeline and prompt."""
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=30,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(image_name)
    ad_logger.info(f"Generated image saved as {image_name}")
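

# Timing uses CUDA events around a steady-state loop; warmup iterations run first so
# one-time costs (compilation, allocator growth) don't skew the measurement, and the
# cudaProfilerStart/Stop calls bracket the timed region for external profilers.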
@torch.inference_mode()
def benchmark_model(model, generate_dummy_inputs, benchmarking_runs=200, warmup_runs=25) -> float:
    """Return the average per-run latency of the model in seconds."""
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    input_data = generate_dummy_inputs()

    for _ in range(warmup_runs):
        _ = model(**input_data)

    torch.cuda.synchronize()

    torch.cuda.profiler.cudart().cudaProfilerStart()
    start_event.record()
    for _ in range(benchmarking_runs):
        _ = model(**input_data)
    end_event.record()
    end_event.synchronize()
    torch.cuda.profiler.cudart().cudaProfilerStop()

    # elapsed_time is in milliseconds; average over runs and convert to seconds.
    return start_event.elapsed_time(end_event) / benchmarking_runs / 1000
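

# Shapes below match the FLUX.1-dev transformer inputs: 4096 latent image tokens of
# width 64, a 768-dim pooled text embedding, and 512 text tokens with hidden size 4096.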
def generate_dummy_inputs(
    device: str = "cuda", model_dtype: torch.dtype = torch.bfloat16
) -> dict[str, Any]:
    """Generate dummy inputs for the FLUX transformer."""
    assert model_dtype in [torch.bfloat16, torch.float16], (
        "Model dtype must be either bfloat16 or float16"
    )
    dummy_input = {}
    text_maxlen = 512
    dummy_input["hidden_states"] = torch.randn(1, 4096, 64, dtype=model_dtype, device=device)
    dummy_input["timestep"] = torch.tensor([1.0], dtype=model_dtype, device=device)
    dummy_input["guidance"] = torch.full((1,), 3.5, dtype=torch.float32, device=device)
    dummy_input["pooled_projections"] = torch.randn(1, 768, dtype=model_dtype, device=device)
    dummy_input["encoder_hidden_states"] = torch.randn(
        1, text_maxlen, 4096, dtype=model_dtype, device=device
    )
    dummy_input["txt_ids"] = torch.randn(text_maxlen, 3, dtype=torch.float32, device=device)
    dummy_input["img_ids"] = torch.randn(4096, 3, dtype=torch.float32, device=device)
    dummy_input["joint_attention_kwargs"] = {}
    dummy_input["return_dict"] = False

    return dummy_input
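

# Property accessors used to pin DiffusionPipeline._execution_device to CUDA (installed
# in main), so the pipeline keeps dispatching to the GPU after the transformer is
# swapped for the compiled GraphModule.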
def execution_device_getter(self):
    return torch.device("cuda")  # Always return CUDA


def execution_device_setter(self, value):
    self.__dict__["_execution_device"] = torch.device("cuda")  # Force CUDA


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="The model to use for inference.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="a photo of an astronaut riding a horse on mars",
        help="The prompt to use for inference.",
    )
    parser.add_argument(
        "--hf_inference",
        action="store_true",
        help="Whether to also generate an image with the base HF model, in addition to the auto-deploy generation.",
    )
    parser.add_argument(
        "--restore_from", type=str, help="The quantized checkpoint path to restore the model from."
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Whether to benchmark the model.",
    )
    parser.add_argument(
        "--skip_image_generation",
        action="store_true",
        help="Whether to skip image generation.",
    )
    args = parser.parse_args()
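    # Pin the pipeline's execution device to CUDA via the property overrides above.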
    DiffusionPipeline._execution_device = property(execution_device_getter, execution_device_setter)
    pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=torch.bfloat16)
    pipe.to("cuda")
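
    # Optional baseline: run image generation and/or latency measurement with the
    # unmodified HF transformer for comparison.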
    if args.hf_inference:
        if not args.skip_image_generation:
            ad_logger.info("Generating image with the torch pipeline")
            generate_image(pipe, args.prompt, "hf_mars_horse.png")
        if args.benchmark:
            latency = benchmark_model(pipe.transformer, generate_dummy_inputs)
            ad_logger.info(f"HuggingFace Latency: {latency} seconds")

    model = pipe.transformer
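
    # If a ModelOpt quantized checkpoint is given, restore its quantization state
    # onto the transformer before export.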
    if args.restore_from:
        ad_logger.info(f"Restoring model from {args.restore_from}")
        mto.restore(model, args.restore_from)

    flux_config = pipe.transformer.config
    flux_kwargs = generate_dummy_inputs()
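
    # Export the transformer to a torch.fx GraphModule using the dummy kwargs as
    # example inputs.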
    gm = torch_export_to_gm(model, args=(), kwargs=flux_kwargs, clone=True)
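
    # Re-apply quantization ops to the exported graph, then reload the quantized
    # weights; strict=False tolerates state dict entries the exported graph lacks.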
    if args.restore_from:
        quant_state_dict = model.state_dict()
        quantize(gm, {}).to("cuda")
        gm.load_state_dict(quant_state_dict, strict=False)
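
    # Fuse eligible GEMM operations in the graph to reduce kernel launches.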
    fuse_gemms(gm)
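
    # Compile the transformed graph with auto-deploy's "torch-opt" backend.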
    compiler_cls = CompileBackendRegistry.get("torch-opt")
    gm = compiler_cls(gm, args=(), kwargs=flux_kwargs).compile()
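
    # Swap the compiled module into the pipeline; downstream code still reads the
    # transformer config, so carry it over.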
    del model
    fx_model = gm
    fx_model.config = flux_config
    pipe.transformer = fx_model
    if not args.skip_image_generation:
        ad_logger.info("Generating image with the exported auto-deploy model")
        generate_image(pipe, args.prompt, "autodeploy_mars_horse_gm.png")

    if args.benchmark:
        latency = benchmark_model(fx_model, generate_dummy_inputs)
        ad_logger.info(f"AutoDeploy Latency: {latency} seconds")


if __name__ == "__main__":
    main()