TensorRT-LLMs/examples/auto_deploy/build_and_run_flux.py

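# Example invocation (hypothetical; flags correspond to the argparse options defined in main() below):
#   python build_and_run_flux.py --hf_inference --benchmark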
import argparse
from typing import Any

import modelopt.torch.opt as mto
import torch
from diffusers import DiffusionPipeline

from tensorrt_llm._torch.auto_deploy.compile import compile_and_capture
from tensorrt_llm._torch.auto_deploy.transformations.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transformations.library.fusion import fuse_gemms
from tensorrt_llm._torch.auto_deploy.transformations.library.quantization import quantize
from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
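
# Allow more graph recompilations before torch._dynamo stops compiling and falls back to eager.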
torch._dynamo.config.cache_size_limit = 100


def generate_image(pipe: DiffusionPipeline, prompt: str, image_name: str) -> None:
    """Generate an image using the given pipeline and prompt."""
    seed = 42
    image = pipe(
        prompt,
        output_type="pil",
        num_inference_steps=30,
        generator=torch.Generator("cuda").manual_seed(seed),
    ).images[0]
    image.save(image_name)
    ad_logger.info(f"Image generated and saved as {image_name}")


@torch.inference_mode()
def benchmark_model(model, generate_dummy_inputs, benchmarking_runs=200, warmup_runs=25) -> float:
    """Return the average latency of the model in seconds."""
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    input_data = generate_dummy_inputs()

    # Warm up so that lazy initialization and autotuning do not skew the measurement.
    for _ in range(warmup_runs):
        _ = model(**input_data)
    torch.cuda.synchronize()

    # Time the benchmarking runs with CUDA events; the cudaProfilerStart/Stop range also
    # delimits the region of interest when the script runs under an external profiler.
    torch.cuda.profiler.cudart().cudaProfilerStart()
    start_event.record()
    for _ in range(benchmarking_runs):
        _ = model(**input_data)
    end_event.record()
    end_event.synchronize()
    torch.cuda.profiler.cudart().cudaProfilerStop()

    # elapsed_time() is in milliseconds; convert to seconds per run.
    return start_event.elapsed_time(end_event) / benchmarking_runs / 1000


def generate_dummy_inputs(
    device: str = "cuda", model_dtype: torch.dtype = torch.bfloat16
) -> dict[str, Any]:
    """Generate dummy inputs for the flux transformer."""
    assert model_dtype in [torch.bfloat16, torch.float16], (
        "Model dtype must be either bfloat16 or float16"
    )
    dummy_input = {}
    text_maxlen = 512
    # Packed image latents: 4096 tokens of width 64 correspond to a 1024x1024 image.
    dummy_input["hidden_states"] = torch.randn(1, 4096, 64, dtype=model_dtype, device=device)
    dummy_input["timestep"] = torch.tensor(data=[1.0] * 1, dtype=model_dtype, device=device)
    dummy_input["guidance"] = torch.full((1,), 3.5, dtype=torch.float32, device=device)
    # Pooled CLIP text embedding and T5 token embeddings, respectively.
    dummy_input["pooled_projections"] = torch.randn(1, 768, dtype=model_dtype, device=device)
    dummy_input["encoder_hidden_states"] = torch.randn(
        1, text_maxlen, 4096, dtype=model_dtype, device=device
    )
    # Position ids for the text and image tokens (used for rotary embeddings).
    dummy_input["txt_ids"] = torch.randn(text_maxlen, 3, dtype=torch.float32, device=device)
    dummy_input["img_ids"] = torch.randn(4096, 3, dtype=torch.float32, device=device)
    dummy_input["joint_attention_kwargs"] = {}
    dummy_input["return_dict"] = False
    return dummy_input
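

# Getter/setter pair used in main() to patch DiffusionPipeline._execution_device so that
# the pipeline always reports CUDA as its execution device.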
def execution_device_getter(self):
    return torch.device("cuda")  # Always return CUDA


def execution_device_setter(self, value):
    self.__dict__["_execution_device"] = torch.device("cuda")  # Force CUDA


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        type=str,
        default="black-forest-labs/FLUX.1-dev",
        help="The model to use for inference.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="a photo of an astronaut riding a horse on mars",
        help="The prompt to use for inference.",
    )
    parser.add_argument(
        "--hf_inference",
        action="store_true",
        help="Whether to generate an image with the base HF model in addition to AutoDeploy generation",
    )
    parser.add_argument(
        "--restore_from", type=str, help="The quantized checkpoint path to restore the model from"
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Whether to benchmark the model",
    )
    parser.add_argument(
        "--skip_image_generation",
        action="store_true",
        help="Whether to skip image generation",
    )
    args = parser.parse_args()
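
    # Patch _execution_device at the class level so the pipeline always reports CUDA as its
    # execution device (see execution_device_getter/setter above).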
    DiffusionPipeline._execution_device = property(execution_device_getter, execution_device_setter)

    pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    if args.hf_inference:
        if not args.skip_image_generation:
            ad_logger.info("Generating image with the torch pipeline")
            generate_image(pipe, args.prompt, "hf_mars_horse.png")
        if args.benchmark:
            latency = benchmark_model(pipe.transformer, generate_dummy_inputs)
            ad_logger.info(f"HuggingFace Latency: {latency} seconds")
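
    # Export the transformer to a torch.fx GraphModule and apply the AutoDeploy transformations:
    # optional quantization restore, GEMM fusion, and compilation with the torch-opt backend.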
    model = pipe.transformer
    if args.restore_from:
        ad_logger.info(f"Restoring model from {args.restore_from}")
        mto.restore(model, args.restore_from)
    flux_config = pipe.transformer.config
    flux_kwargs = generate_dummy_inputs()

    gm = torch_export_to_gm(model, args=(), kwargs=flux_kwargs, clone=True)
    if args.restore_from:
        quant_state_dict = model.state_dict()
        gm = quantize(gm, {}).to("cuda")
        gm.load_state_dict(quant_state_dict, strict=False)
    gm = fuse_gemms(gm)
    gm = compile_and_capture(gm, backend="torch-opt", args=(), kwargs=flux_kwargs)

    del model
    fx_model = gm
    fx_model.config = flux_config
    pipe.transformer = fx_model

    if not args.skip_image_generation:
        ad_logger.info("Generating image with the exported auto-deploy model")
        generate_image(pipe, args.prompt, "autodeploy_mars_horse_gm.png")
    if args.benchmark:
        latency = benchmark_model(fx_model, generate_dummy_inputs)
        ad_logger.info(f"AutoDeploy Latency: {latency} seconds")


if __name__ == "__main__":
    main()