TensorRT-LLMs/tests/integration/defs/examples/test_flux.py
Ajinkya Rasane 8d7cda2318
[None][chore] Update the Flux autodeploy example (#8434)
Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
Co-authored-by: Frida Hou <201670829+Fridah-nv@users.noreply.github.com>
2025-11-18 14:16:04 -08:00

134 lines
4.8 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration test for build_and_run_flux.py with multiple quantization formats."""
import importlib.util
import os
import pytest
import torch
from build_and_run_flux import clip_model as load_clip_model
from build_and_run_flux import compute_clip_similarity, main
from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger
# CLIP is treated as available when the `transformers` package can be located.
CLIP_AVAILABLE = importlib.util.find_spec("transformers") is not None
class FluxTestConfig:
    """Constants shared by the Flux end-to-end integration test."""

    # Model under test; the FLUX_MODEL_ID environment variable overrides the default.
    MODEL_ID = os.getenv("FLUX_MODEL_ID", "black-forest-labs/FLUX.1-dev")

    # Text prompt used both for image generation and for the CLIP comparison.
    PROMPT = "a photo of an astronaut riding a horse on mars"

    # Lowest CLIP similarity score at which the generated image is accepted.
    MIN_CLIP_SIMILARITY = 0.25

    # NOTE(review): NUM_INFERENCE_STEPS and BACKEND are declared here, but the
    # visible test code does not forward them to main() -- confirm whether they
    # are consumed elsewhere.
    NUM_INFERENCE_STEPS = 20
    MAX_BATCH_SIZE = 1
    BACKEND = "torch-opt"

    # Optional quantized checkpoint locations, read from the environment.
    # Each is None when the corresponding variable is unset.
    FP8_CHECKPOINT = os.getenv("FLUX_FP8_CHECKPOINT")
    FP4_CHECKPOINT = os.getenv("FLUX_FP4_CHECKPOINT")
@pytest.fixture(scope="module")
def clip_model():
    """Provide a CLIP model that is loaded a single time per test module.

    The requesting test is skipped when ``CLIP_AVAILABLE`` is false.
    """
    if CLIP_AVAILABLE:
        return load_clip_model()
    # pytest.skip raises, so nothing is returned on this path.
    pytest.skip("CLIP not available")
def _quantized_case(precision, checkpoint):
    """Build the parametrize entry for a quantized precision.

    The entry is skipped when ``checkpoint`` is unset, empty, or does not
    point at an existing path.
    """
    unavailable = not (checkpoint and os.path.exists(checkpoint))
    return pytest.param(
        precision,
        checkpoint,
        marks=pytest.mark.skipif(
            unavailable,
            reason=f"{precision.upper()} checkpoint not available",
        ),
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for Flux model")
@pytest.mark.slow  # Mark as slow test
class TestFluxIntegration:
    """Integration tests exercising the Flux example across precisions."""

    @pytest.mark.parametrize(
        "precision,checkpoint_path",
        [
            # bf16 runs without a quantized checkpoint.
            ("bf16", None),
            _quantized_case("fp8", FluxTestConfig.FP8_CHECKPOINT),
            _quantized_case("fp4", FluxTestConfig.FP4_CHECKPOINT),
        ],
    )
    def test_flux_e2e_with_clip_validation(self, precision, checkpoint_path, clip_model, tmp_path):
        """Run the Flux example end to end and validate its output with CLIP.

        Steps:
            1. Invoke ``main`` from build_and_run_flux.py with CLI-style arguments.
            2. Check that the output image file was written.
            3. Require the CLIP similarity between the image and the prompt to
               reach ``FluxTestConfig.MIN_CLIP_SIMILARITY``.

        Args:
            precision: Label of the precision under test; used in the output
                file name and in log messages.
            checkpoint_path: Checkpoint forwarded via ``--restore_from`` when
                truthy; omitted otherwise.
            clip_model: CLIP model supplied by the ``clip_model`` fixture.
            tmp_path: Temporary directory in which the image is written.
        """
        image_file = tmp_path / f"flux_{precision}_output.png"

        # Flag -> value pairs, flattened below into the argv-style list main() takes.
        options = {
            "--model": FluxTestConfig.MODEL_ID,
            "--prompt": FluxTestConfig.PROMPT,
            "--image_path": str(image_file),
            "--max_batch_size": str(FluxTestConfig.MAX_BATCH_SIZE),
        }
        if checkpoint_path:
            options["--restore_from"] = checkpoint_path
        cli_args = [token for pair in options.items() for token in pair]

        ad_logger.info(f"Running main with args for {precision}: {' '.join(cli_args)}")
        main(cli_args)

        assert image_file.exists(), f"Output image not found at {image_file}"

        prompt = FluxTestConfig.PROMPT
        threshold = FluxTestConfig.MIN_CLIP_SIMILARITY
        score = compute_clip_similarity(str(image_file), prompt, clip_model)
        ad_logger.info(f"CLIP similarity score for {precision}: {score:.4f}")

        failure_message = (
            f"CLIP similarity {score:.4f} is below threshold {threshold}. "
            f"Image may not match prompt: '{prompt}'"
        )
        assert score >= threshold, failure_message

        ad_logger.info(f"✓ Test passed for {precision}: CLIP similarity = {score:.4f}")