TensorRT-LLMs/examples/auto_deploy/build_and_run_flux.py


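"""Build and run the FLUX diffusion pipeline with AutoDeploy.

The script exports the FLUX transformer to a graph module, applies the inference
optimizer transforms configured in a YAML file, compiles the result with the
selected backend, and then optionally generates an image, benchmarks the backbone,
and scores the output against the prompt with CLIP.

Example invocation (illustrative; adjust flags to your setup):
    python build_and_run_flux.py --model black-forest-labs/FLUX.1-dev --benchmark
"""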
import argparse
import os
from collections import abc

import modelopt.torch.opt as mto
import torch
import yaml
from diffusers import DiffusionPipeline
from PIL import Image
from tqdm import tqdm
from transformers import CLIPModel, CLIPProcessor

from tensorrt_llm._torch.auto_deploy.compile import CompileBackendRegistry
from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
from tensorrt_llm._torch.auto_deploy.utils._graph import load_buffers_and_params
from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
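# Allow more torch.compile graph recompilations before torch._dynamo gives up and
# falls back to eager execution.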
torch._dynamo.config.cache_size_limit = 100
dtype_map = {
"Half": torch.float16,
"BFloat16": torch.bfloat16,
"Float": torch.float32,
}
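# Expected YAML layout (illustrative sketch based on the sections validated in
# load_config and the keys read in main; see flux_transforms.yaml for the real
# transform settings):
#
#   export:
#     clone: false
#     strict: false
#   optimizer:
#     ...  # InferenceOptimizer transform configuration
#   compile:
#     backend: torch-opt
#     cuda_graph_batch_sizes: [1]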
def load_config(config_path):
"""Load configuration from YAML file.
Args:
config_path: Path to YAML config file.
Returns:
Dictionary with export, optimizer, and compile configurations.
"""
if not config_path:
raise ValueError("Config path is required. Use --config to specify a YAML config file.")
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config file not found: {config_path}")
ad_logger.info(f"Loading config from {config_path}")
with open(config_path, "r") as f:
config = yaml.safe_load(f)
# Validate required sections
required_sections = ["export", "optimizer", "compile"]
for section in required_sections:
if section not in config:
raise ValueError(f"Config file missing required section: {section}")
return config
# TODO: Reuse the cache context from the original model
class TransformerWrapper(torch.nn.Module):
def __init__(self, compiled_model, config):
super().__init__()
self.model = compiled_model
self.config = config
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
    def cache_context(self, *args, **kwargs):
        # The compiled graph module does not support diffusers' caching feature,
        # so return a no-op context manager.
        from contextlib import nullcontext

        return nullcontext()
def clip_model():
"""Load CLIP model for image-text similarity evaluation."""
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model.eval()
if torch.cuda.is_available():
model = model.to("cuda")
return model, processor
def compute_clip_similarity(image_path: str, prompt: str, clip_model_and_processor) -> float:
"""Compute CLIP similarity score between generated image and text prompt.
Args:
image_path: Path to the generated image
prompt: Text prompt used to generate the image
clip_model_and_processor: Tuple of (CLIP model, CLIP processor)
Returns:
Similarity score between 0 and 1
"""
model, processor = clip_model_and_processor
image = Image.open(image_path).convert("RGB")
inputs = processor(text=[prompt], images=image, return_tensors="pt", padding=True)
if torch.cuda.is_available():
inputs = {k: v.to("cuda") for k, v in inputs.items()}
with torch.no_grad():
outputs = model(**inputs)
image_embeds = outputs.image_embeds
text_embeds = outputs.text_embeds
# Normalize embeddings
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
# Compute cosine similarity
similarity = (image_embeds @ text_embeds.T).squeeze().item()
return similarity
@torch.inference_mode()
def generate_image(pipe: DiffusionPipeline, prompt: str, image_name: str) -> None:
"""Generate an image using the given pipeline and prompt."""
image = pipe(
prompt,
output_type="pil",
num_inference_steps=30,
generator=torch.Generator("cuda").manual_seed(42),
).images[0]
image.save(image_name)
ad_logger.info(f"Image generated saved as {image_name}")
@torch.inference_mode()
def benchmark_backbone_standalone(
pipe, num_warmup=10, num_benchmark=100, model_name="flux-dev", model_dtype="Half"
):
"""Benchmark the backbone model directly without running the full pipeline."""
backbone = pipe.transformer if hasattr(pipe, "transformer") else pipe.unet
# Generate dummy inputs for the backbone
dummy_inputs = _gen_dummy_inp_flux(backbone)
# Warmup
ad_logger.info(f"Warming up: {num_warmup} iterations")
for _ in tqdm(range(num_warmup), desc="Warmup"):
_ = backbone(**dummy_inputs)
# Benchmark
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
ad_logger.info(f"Benchmarking: {num_benchmark} iterations")
times = []
for _ in tqdm(range(num_benchmark), desc="Benchmark"):
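        # cudaProfilerStart/Stop delimit the timed region for external profilers
        # (e.g. Nsight Systems launched with --capture-range=cudaProfilerApi).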
torch.cuda.profiler.cudart().cudaProfilerStart()
start_event.record()
_ = backbone(**dummy_inputs)
end_event.record()
torch.cuda.synchronize()
torch.cuda.profiler.cudart().cudaProfilerStop()
times.append(start_event.elapsed_time(end_event))
avg_latency = sum(times) / len(times)
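    # Simple index-based percentiles over the sorted per-iteration timings.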
times = sorted(times)
p50 = times[len(times) // 2]
p95 = times[int(len(times) * 0.95)]
p99 = times[int(len(times) * 0.99)]
ad_logger.info(f"\nBackbone-only inference latency ({model_dtype}):")
ad_logger.info(f" Average: {avg_latency:.2f} ms")
ad_logger.info(f" P50: {p50:.2f} ms")
ad_logger.info(f" P95: {p95:.2f} ms")
ad_logger.info(f" P99: {p99:.2f} ms")
return avg_latency
def torch_to(data, *args, **kwargs):
"""Try to recursively move the data to the specified args/kwargs."""
if isinstance(data, torch.Tensor):
return data.to(*args, **kwargs)
elif isinstance(data, (tuple, list)):
return type(data)([torch_to(val, *args, **kwargs) for val in data])
elif isinstance(data, abc.Mapping):
return {k: torch_to(val, *args, **kwargs) for k, val in data.items()}
return data
def _gen_dummy_inp_flux(backbone, min_bs=1):
cfg = backbone.config
    text_maxlen = 512  # max text sequence length used by the FLUX text encoder
    img_dim = 4096  # number of latent image tokens (a 1024x1024 image -> 64x64 = 4096 patches)
    dtype = torch.bfloat16
dummy_input = {
"hidden_states": torch.randn(min_bs, img_dim, cfg.in_channels, dtype=dtype),
"encoder_hidden_states": torch.randn(
min_bs, text_maxlen, cfg.joint_attention_dim, dtype=dtype
),
"pooled_projections": torch.randn(min_bs, cfg.pooled_projection_dim, dtype=dtype),
"timestep": torch.ones(1, dtype=dtype),
"img_ids": torch.randn(img_dim, 3, dtype=torch.float32),
"txt_ids": torch.randn(text_maxlen, 3, dtype=torch.float32),
"return_dict": False,
"joint_attention_kwargs": {},
}
if cfg.guidance_embeds: # flux-dev
dummy_input["guidance"] = torch.full((1,), 3.5, dtype=torch.float32)
dummy_input = torch_to(dummy_input, device="cuda")
return dummy_input
def execution_device_getter(self):
    """Always report CUDA as the pipeline's execution device."""
    return torch.device("cuda")
def execution_device_setter(self, value):
    """Pin the pipeline's execution device to CUDA, ignoring the requested value."""
    self.__dict__["_execution_device"] = torch.device("cuda")
def main(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument(
"--model",
type=str,
default="black-forest-labs/FLUX.1-dev",
help="The model to use for inference.",
)
parser.add_argument(
"--prompt",
type=str,
default="a photo of an astronaut riding a horse on mars",
help="The prompt to use for inference.",
)
parser.add_argument(
"--hf_inference",
action="store_true",
help="Whether to generate image with the base hf model in addition to autodeploy generation",
)
parser.add_argument(
"--restore_from", type=str, help="The quantized checkpoint path to restore the model from"
)
parser.add_argument(
"--benchmark",
action="store_true",
help="Whether to benchmark the model",
)
parser.add_argument(
"--skip_image_generation",
action="store_true",
help="Whether to skip image generation",
)
parser.add_argument(
"--max_batch_size",
type=int,
default=1,
help="The max batch size to use for the model",
)
parser.add_argument(
"--image_path",
type=str,
default="output.png",
help="Path to save the generated image (default: output.png)",
)
parser.add_argument(
"--config",
type=str,
default=os.path.join(os.path.dirname(__file__), "flux_transforms.yaml"),
help="Path to YAML config file for export, optimizer, and compile settings (default: flux_transforms.yaml)",
)
args = parser.parse_args(argv)
# Validate max_batch_size
if args.max_batch_size <= 0:
raise ValueError(f"max_batch_size must be positive, got {args.max_batch_size}")
DiffusionPipeline._execution_device = property(execution_device_getter, execution_device_setter)
pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Load CLIP model for similarity evaluation if generating images
clip_model_processor = None
if not args.skip_image_generation:
ad_logger.info("Loading CLIP model for similarity evaluation")
clip_model_processor = clip_model()
if args.hf_inference:
if not args.skip_image_generation:
ad_logger.info("Generating image with the torch pipeline")
hf_image_path = f"hf_{args.image_path}"
generate_image(pipe, args.prompt, hf_image_path)
# Compute CLIP similarity score
similarity = compute_clip_similarity(hf_image_path, args.prompt, clip_model_processor)
ad_logger.info(f"CLIP similarity score (HF): {similarity:.4f}")
if args.benchmark:
ad_logger.info("Benchmarking HuggingFace model")
latency = benchmark_backbone_standalone(pipe, model_dtype="BFloat16")
ad_logger.info(f"HuggingFace Average Inference Latency: {latency:.2f} ms")
model = pipe.transformer
flux_config = pipe.transformer.config
flux_kwargs = _gen_dummy_inp_flux(model, min_bs=args.max_batch_size)
# Load config from YAML
config = load_config(args.config)
# Restore quantizers
if args.restore_from:
ad_logger.info(f"Restoring model from {args.restore_from}")
try:
mto.restore(model, args.restore_from)
quant_state_dict = model.state_dict()
load_buffers_and_params(
model, quant_state_dict, strict_missing=False, strict_unexpected=False, clone=False
)
except Exception as e:
ad_logger.error(f"Failed to restore model from {args.restore_from}: {e}")
raise
# Export to graph module with config params
ad_logger.info("Exporting model to graph module...")
export_config = config["export"]
gm = torch_export_to_gm(
model,
args=(),
kwargs=flux_kwargs,
clone=export_config.get("clone", False),
strict=export_config.get("strict", False),
)
# Apply inference optimizer fusions
optimizer_config = config.get("optimizer")
if optimizer_config:
ad_logger.info("Applying inference optimizer fusions (FP8 and FP4)...")
optimizer = InferenceOptimizer(factory=None, config=optimizer_config)
gm = optimizer(cm=None, mod=gm)
ad_logger.info("Inference optimizer fusions applied successfully")
else:
ad_logger.info("No optimizer transforms configured, skipping optimizer fusions")
# Compile model with config params
compile_config = config["compile"]
backend = compile_config.get("backend", "torch-opt")
cuda_graph_batch_sizes = compile_config.get("cuda_graph_batch_sizes", None)
# Validate backend availability
if not CompileBackendRegistry.has(backend):
available = CompileBackendRegistry.list()
raise ValueError(f"Backend '{backend}' not found. Available backends: {available}")
ad_logger.info(f"Compiling model with backend: {backend}")
compiler_cls = CompileBackendRegistry.get(backend)
gm = compiler_cls(
gm,
args=(),
max_batch_size=args.max_batch_size,
kwargs=flux_kwargs,
cuda_graph_batch_sizes=cuda_graph_batch_sizes,
).compile()
    # Drop the local reference to the original transformer; it is fully released
    # once pipe.transformer is re-assigned to the compiled wrapper below.
    del model
fx_model = TransformerWrapper(gm, flux_config)
pipe.transformer = fx_model
if not args.skip_image_generation:
ad_logger.info("Generating image with the exported auto-deploy model")
generate_image(pipe, args.prompt, args.image_path)
# Compute CLIP similarity score
similarity = compute_clip_similarity(args.image_path, args.prompt, clip_model_processor)
ad_logger.info(f"CLIP similarity score (AutoDeploy): {similarity:.4f}")
if args.benchmark:
ad_logger.info("Benchmarking AutoDeploy model")
latency = benchmark_backbone_standalone(pipe, model_dtype="BFloat16")
ad_logger.info(f"AutoDeploy Average Inference Latency: {latency:.2f} ms")
if __name__ == "__main__":
main()