TensorRT-LLM/examples/models/contrib/mmdit/sample.py
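
"""Run Stable Diffusion 3.5 text-to-image sampling with the MMDiT transformer
replaced by a TensorRT-LLM engine.

Illustrative usage (editorial note; paths and the engine build itself are
assumptions, see the example's README for the exact build steps):

    python sample.py "A capybara holding a sign that reads 'Hello World'" \
        --tllm_model_dir ./engine_outputs/

For engines built with tensor/context parallelism, the script is typically
launched with one process per rank, e.g. ``mpirun -n <world_size> python
sample.py ...``; launch details depend on your MPI setup.
"""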

import argparse
import json
import os
from functools import wraps
from typing import List, Optional

import numpy as np
import tensorrt as trt
import torch
from cuda import cudart
from diffusers import StableDiffusion3Pipeline
from diffusers.models.modeling_outputs import Transformer2DModelOutput

import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.models.mmdit_sd3.config import SD3Transformer2DModelConfig
from tensorrt_llm.runtime.session import Session


def CUASSERT(cuda_ret):
    err = cuda_ret[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(
            f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
        )
    if len(cuda_ret) > 1:
        return cuda_ret[1:]
    return None
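

# For reference, ``config.json`` in ``--tllm_model_dir`` is expected to expose
# at least the fields this script reads below. The layout sketched here is
# inferred from those reads and is illustrative, not the full schema:
#
#   {
#     "pretrained_config": {
#       "dtype": "float16",
#       "model_path": "<path to the Hugging Face SD3.5 checkpoint>",
#       "mapping": {
#         "world_size": 1, "cp_size": 1, "tp_size": 1, "pp_size": 1,
#         "gpus_per_node": 8
#       }
#     }
#   }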


class TllmMMDiT(object):

    def __init__(self,
                 config,
                 skip_layers: Optional[List[int]] = None,
                 debug_mode: bool = True,
                 stream: torch.cuda.Stream = None,
                 pos_embedder: torch.nn.Module = None):
        self.dtype = config['pretrained_config']['dtype']

        # Update config with `skip_layers`
        config['pretrained_config']['skip_layers'] = skip_layers
        self.config = SD3Transformer2DModelConfig.from_dict(
            config['pretrained_config'])

        # Build the parallel mapping for this rank (this example only supports
        # pp_size == 1).
        rank = tensorrt_llm.mpi_rank()
        world_size = config['pretrained_config']['mapping']['world_size']
        cp_size = config['pretrained_config']['mapping']['cp_size']
        tp_size = config['pretrained_config']['mapping']['tp_size']
        pp_size = config['pretrained_config']['mapping']['pp_size']
        gpus_per_node = config['pretrained_config']['mapping']['gpus_per_node']
        assert pp_size == 1
        self.mapping = tensorrt_llm.Mapping(world_size=world_size,
                                            rank=rank,
                                            cp_size=cp_size,
                                            tp_size=tp_size,
                                            pp_size=1,
                                            gpus_per_node=gpus_per_node)

        local_rank = rank % self.mapping.gpus_per_node
        self.device = torch.device(f'cuda:{local_rank}')
        torch.cuda.set_device(self.device)
        CUASSERT(cudart.cudaSetDevice(local_rank))
        self.stream = stream
        if self.stream is None:
            self.stream = torch.cuda.Stream(self.device)
        torch.cuda.set_stream(self.stream)

        # Note: the engine directory comes from the module-level `args` parsed
        # in `__main__`, so this class assumes it is used from this script.
        engine_file = os.path.join(args.tllm_model_dir, f"rank{rank}.engine")
        logger.info(f'Loading engine from {engine_file}')
        with open(engine_file, "rb") as f:
            engine_buffer = f.read()
        assert engine_buffer is not None
        self.session = Session.from_serialized_engine(engine_buffer)

        self.debug_mode = debug_mode
        self.inputs = {}
        self.outputs = {}
        self.buffer_allocated = False

        expected_tensor_names = [
            'hidden_states', 'encoder_hidden_states', 'pooled_projections',
            'timestep', 'output'
        ]
        found_tensor_names = [
            self.session.engine.get_tensor_name(i)
            for i in range(self.session.engine.num_io_tensors)
        ]
        if not self.debug_mode and set(expected_tensor_names) != set(
                found_tensor_names):
            logger.error(
                f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
            )
            logger.error(
                f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}"
            )
            logger.error(f"Expected tensor names: {expected_tensor_names}")
            logger.error(f"Found tensor names: {found_tensor_names}")
            raise RuntimeError(
                "Tensor names in engine are not the same as expected.")
        if self.debug_mode:
            self.debug_tensors = list(
                set(found_tensor_names) - set(expected_tensor_names))

    def _tensor_dtype(self, name):
        # Return the torch dtype of the named engine tensor for convenience.
        dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name))
        return dtype
    def _setup(self, outputs_shape):
        for i in range(self.session.engine.num_io_tensors):
            name = self.session.engine.get_tensor_name(i)
            if self.session.engine.get_tensor_mode(
                    name) == trt.TensorIOMode.OUTPUT:
                shape = list(self.session.engine.get_tensor_shape(name))
                if self.debug_mode:
                    shape = list(self.session.engine.get_tensor_shape(name))
                    if shape[0] == -1:
                        shape[0] = outputs_shape['output'][0]
                else:
                    shape = outputs_shape[name]
                self.outputs[name] = torch.empty(shape,
                                                 dtype=self._tensor_dtype(name),
                                                 device=self.device)
        self.buffer_allocated = True

    def cuda_stream_guard(func):
        """Sync external stream and set current stream to the one bound to the session. Reset on exit.
        """

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            external_stream = torch.cuda.current_stream()
            if external_stream != self.stream:
                external_stream.synchronize()
                torch.cuda.set_stream(self.stream)
            ret = func(self, *args, **kwargs)
            if external_stream != self.stream:
                self.stream.synchronize()
                torch.cuda.set_stream(external_stream)
            return ret

        return wrapper
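
    # Design note (editorial): the diffusers pipeline issues its work on the
    # caller's current torch stream, while the TRT session runs on the stream
    # created in __init__. The guard above synchronizes the caller's stream,
    # switches to the session's stream for the engine call, then restores the
    # caller's stream, so callers never observe a stream switch.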
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    @cuda_stream_guard
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        block_controlnet_hidden_states: List = None,
        joint_attention_kwargs=None,
        return_dict: bool = True,
        skip_layers: Optional[List[int]] = None,
    ):
        if block_controlnet_hidden_states is not None:
            raise NotImplementedError("ControlNet is not supported yet.")
        if joint_attention_kwargs is not None:
            if joint_attention_kwargs.get("scale", None) is not None:
                raise NotImplementedError("LoRA is not supported yet.")
            if "ip_adapter_image_embeds" in joint_attention_kwargs:
                raise NotImplementedError("IP-Adapter is not supported yet.")
        if skip_layers is not None:
            raise RuntimeWarning(
                "Please initialize model with `skip_layers` instead of passing via `forward`."
            )

        self._setup(outputs_shape={'output': hidden_states.shape})
        if not self.buffer_allocated:
            raise RuntimeError('Buffer not allocated, please call setup first!')

        inputs = {
            'hidden_states':
            hidden_states.to(str_dtype_to_torch(self.dtype)),
            'encoder_hidden_states':
            encoder_hidden_states.to(str_dtype_to_torch(self.dtype)),
            'pooled_projections':
            pooled_projections.to(str_dtype_to_torch(self.dtype)),
            'timestep':
            timestep.to(str_dtype_to_torch(self.dtype)),
        }
        for k, v in inputs.items():
            inputs[k] = v.cuda().contiguous()
        self.inputs.update(**inputs)
        self.session.set_shapes(self.inputs)
        ok = self.session.run(self.inputs, self.outputs,
                              self.stream.cuda_stream)
        if not ok:
            raise RuntimeError('Executing TRT engine failed!')
        output = self.outputs['output'].to(hidden_states.device)

        if self.debug_mode:
            torch.cuda.synchronize()
            output_np = {
                k: v.cpu().float().numpy()
                for k, v in self.outputs.items()
            }
            np.savez("tllm_output.npz", **output_np)

        if not return_dict:
            return (output, )
        else:
            return Transformer2DModelOutput(sample=output)
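
    # Illustrative standalone call, independent of the diffusers pipeline. The
    # shapes are assumptions for SD3.5 at 1024x1024 (16 latent channels,
    # 128x128 latents, 333 joint text tokens of width 4096, pooled width 2048)
    # and may differ for other checkpoints or resolutions:
    #
    #   model = TllmMMDiT(config)
    #   sample = model(
    #       hidden_states=torch.randn(2, 16, 128, 128, device='cuda'),
    #       encoder_hidden_states=torch.randn(2, 333, 4096, device='cuda'),
    #       pooled_projections=torch.randn(2, 2048, device='cuda'),
    #       timestep=torch.tensor([1000, 1000], device='cuda'),
    #   ).sample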


def main(args):
    tensorrt_llm.logger.set_level(args.log_level)
    assert torch.cuda.is_available()

    config_file = os.path.join(args.tllm_model_dir, 'config.json')
    with open(config_file) as f:
        config = json.load(f)

    pipe = StableDiffusion3Pipeline.from_pretrained(
        config['pretrained_config']['model_path'], torch_dtype=torch.float16)
    pipe.to("cuda")

    # Replace the SD3.5 transformer with the TRT-LLM model.
    torch_pos_embed = pipe.transformer.pos_embed
    del pipe.transformer
    torch.cuda.empty_cache()

    # Load model:
    model = TllmMMDiT(config,
                      debug_mode=args.debug_mode,
                      pos_embedder=torch_pos_embed)
    pipe.transformer = model

    image = pipe(args.prompt,
                 height=1024,
                 width=1024,
                 guidance_scale=4.5,
                 num_inference_steps=40,
                 generator=torch.Generator("cpu").manual_seed(0)).images[0]
    if tensorrt_llm.mpi_rank() == 0:
        image.save("sd3.5-mmdit.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        'prompt',
        nargs='*',
        default=
        "A capybara holding a sign that reads 'Hello World' in the forest.",
        help="Text prompt(s) to guide image generation")
    parser.add_argument("--tllm_model_dir",
                        type=str,
                        default='./engine_outputs/')
    parser.add_argument("--gpus_per_node", type=int, default=8)
    parser.add_argument('--log_level', type=str, default='info')
    parser.add_argument("--debug_mode", action='store_true')
    args = parser.parse_args()
    main(args)