TensorRT-LLM/examples/models/contrib/dit/sample.py
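
"""Sample images from a DiT model compiled as a TensorRT-LLM engine.

Loads a per-rank TensorRT engine, runs the diffusion sampling loop with
classifier-free guidance, decodes the resulting latents with a separately
built VAE decoder engine, and writes the images to ``sample.png``.
"""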

import argparse
import json
import os
from functools import wraps

import tensorrt as trt
import torch
from cuda import cudart
from diffusion import DiTDiffusionPipeline
from torchvision.utils import save_image

import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_torch, trt_dtype_to_torch
from tensorrt_llm.logger import logger
from tensorrt_llm.plugin.plugin import CustomAllReduceHelper
from tensorrt_llm.runtime.session import Session, TensorInfo


def CUASSERT(cuda_ret):
    err = cuda_ret[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(
            f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
        )
    if len(cuda_ret) > 1:
        return cuda_ret[1:]
    return None
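
# CUASSERT unwraps the (err, *values) tuples returned by cuda-python, e.g.
# (illustrative):
#   free, total = CUASSERT(cudart.cudaMemGetInfo())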


class TllmDiT(object):

    def __init__(self,
                 config,
                 debug_mode=True,
                 stream: torch.cuda.Stream = None):
        self.dtype = config['pretrained_config']['dtype']

        rank = tensorrt_llm.mpi_rank()
        world_size = config['pretrained_config']['mapping']['world_size']
        cp_size = config['pretrained_config']['mapping']['cp_size']
        tp_size = config['pretrained_config']['mapping']['tp_size']
        pp_size = config['pretrained_config']['mapping']['pp_size']
        assert pp_size == 1
        # NOTE: relies on the module-level `args` parsed in the __main__ block.
        self.mapping = tensorrt_llm.Mapping(world_size=world_size,
                                            rank=rank,
                                            cp_size=cp_size,
                                            tp_size=tp_size,
                                            pp_size=1,
                                            gpus_per_node=args.gpus_per_node)

        local_rank = rank % self.mapping.gpus_per_node
        self.device = torch.device(f'cuda:{local_rank}')
        torch.cuda.set_device(self.device)
        CUASSERT(cudart.cudaSetDevice(local_rank))

        self.stream = stream
        if self.stream is None:
            self.stream = torch.cuda.Stream(self.device)
        torch.cuda.set_stream(self.stream)

        engine_file = os.path.join(args.tllm_model_dir, f"rank{rank}.engine")
        logger.info(f'Loading engine from {engine_file}')
        with open(engine_file, "rb") as f:
            engine_buffer = f.read()
        assert engine_buffer is not None
        self.session = Session.from_serialized_engine(engine_buffer)

        self.debug_mode = debug_mode
        self.inputs = {}
        self.outputs = {}
        self.buffer_allocated = False

        expected_tensor_names = ['latent', 'timestep', 'label', 'output']
        if self.mapping.tp_size > 1:
            self.buffer, self.all_reduce_workspace = CustomAllReduceHelper.allocate_workspace(
                self.mapping,
                CustomAllReduceHelper.max_workspace_size_auto(
                    self.mapping.tp_size))
            self.inputs['all_reduce_workspace'] = self.all_reduce_workspace
            expected_tensor_names += ['all_reduce_workspace']

        found_tensor_names = [
            self.session.engine.get_tensor_name(i)
            for i in range(self.session.engine.num_io_tensors)
        ]
        if not self.debug_mode and set(expected_tensor_names) != set(
                found_tensor_names):
            logger.error(
                f"The following expected tensors are not found: {set(expected_tensor_names).difference(set(found_tensor_names))}"
            )
            logger.error(
                f"Those tensors in engine are not expected: {set(found_tensor_names).difference(set(expected_tensor_names))}"
            )
            logger.error(f"Expected tensor names: {expected_tensor_names}")
            logger.error(f"Found tensor names: {found_tensor_names}")
            raise RuntimeError(
                "Tensor names in engine are not the same as expected.")
        if self.debug_mode:
            self.debug_tensors = list(
                set(found_tensor_names) - set(expected_tensor_names))

    def _tensor_dtype(self, name):
        # return torch dtype given tensor name for convenience
        dtype = trt_dtype_to_torch(self.session.engine.get_tensor_dtype(name))
        return dtype

    def _setup(self, batch_size):
        # Allocate output buffers once the batch size is known.
        for i in range(self.session.engine.num_io_tensors):
            name = self.session.engine.get_tensor_name(i)
            if self.session.engine.get_tensor_mode(
                    name) == trt.TensorIOMode.OUTPUT:
                shape = list(self.session.engine.get_tensor_shape(name))
                # cond_eps/uncond_eps each hold half of a classifier-free
                # guidance batch.
                shape[0] = batch_size // 2 if name in [
                    'cond_eps', 'uncond_eps'
                ] else batch_size
                self.outputs[name] = torch.empty(shape,
                                                 dtype=self._tensor_dtype(name),
                                                 device=self.device)
        self.buffer_allocated = True

    def cuda_stream_guard(func):
        """Sync external stream and set current stream to the one bound to the
        session. Reset on exit.
        """

        @wraps(func)
        def wrapper(self, *args, **kwargs):
            external_stream = torch.cuda.current_stream()
            if external_stream != self.stream:
                external_stream.synchronize()
                torch.cuda.set_stream(self.stream)
            ret = func(self, *args, **kwargs)
            if external_stream != self.stream:
                self.stream.synchronize()
                torch.cuda.set_stream(external_stream)
            return ret

        return wrapper

    @cuda_stream_guard
    def forward(self, latent: torch.Tensor, timestep: torch.Tensor,
                label: torch.Tensor):
        """
        Forward pass of DiT.
        latent: (N, C, H, W)
        timestep: (N,)
        label: (N,)
        """
        self._setup(latent.shape[0])
        if not self.buffer_allocated:
            raise RuntimeError('Buffer not allocated, please call setup first!')

        inputs = {
            'latent': latent.to(str_dtype_to_torch(self.dtype)),
            'timestep': timestep.int(),
            'label': label.int()
        }
        self.inputs.update(**inputs)
        self.session.set_shapes(self.inputs)
        ok = self.session.run(self.inputs, self.outputs,
                              self.stream.cuda_stream)
        if not ok:
            raise RuntimeError('Executing TRT engine failed!')

        if self.debug_mode:
            torch.cuda.synchronize()
            for k, v in self.inputs.items():
                print(k, v.sum())
            for k, v in self.outputs.items():
                print(k, v.sum())

        return self.outputs['output']
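

# Illustrative call into the engine (the shapes follow the classifier-free
# guidance layout built in main() below, with conditional and unconditional
# halves concatenated along dim 0):
#   eps = model.forward(latent,    # (2N, 4, latent_size, latent_size)
#                       timestep,  # (2N,) diffusion step indices
#                       labels)    # (2N,) class ids, 1000 = null class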


def vae_decode(samples, engine_path):
    # Load standard plugins
    TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
    trt.init_libnvinfer_plugins(TRT_LOGGER, namespace="")

    logger.info(f'Loading vae engine from {engine_path}')
    with open(engine_path, 'rb') as f:
        engine_buffer = f.read()
    logger.info(f'Creating session from engine {engine_path}')
    session_vae = Session.from_serialized_engine(engine_buffer)

    inputs = {'input': samples}
    output_info = session_vae.infer_shapes(
        [TensorInfo('input', trt.DataType.FLOAT, samples.shape)])
    outputs = {
        t.name: torch.empty(tuple(t.shape),
                            dtype=trt_dtype_to_torch(t.dtype),
                            device='cuda')
        for t in output_info
    }
    stream = torch.cuda.current_stream().cuda_stream
    ok = session_vae.run(inputs, outputs, stream)
    assert ok, "Runtime execution failed for vae session"
    return outputs['output']
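
# The caller divides latents by 0.18215 (the SD-VAE scaling factor DiT was
# trained with) before decoding; the engine maps (N, 4, H/8, W/8) latents to
# (N, 3, H, W) images in roughly [-1, 1].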


def main(args):
    tensorrt_llm.logger.set_level(args.log_level)
    torch.manual_seed(args.seed)
    assert torch.cuda.is_available()
    device = "cuda"

    # Load model:
    config_file = os.path.join(args.tllm_model_dir, 'config.json')
    with open(config_file) as f:
        config = json.load(f)
    model = TllmDiT(config, debug_mode=args.debug_mode)
    diffusion = DiTDiffusionPipeline(model.forward,
                                     timestep_respacing=args.num_sampling_steps)

    # Create sampling noise and random class labels:
    latent_size = args.image_size // 8
    latent = torch.randn(args.batch_size,
                         4,
                         latent_size,
                         latent_size,
                         device=device)
    labels = torch.randint(args.num_classes, [args.batch_size], device=device)

    # Set up classifier-free guidance: duplicate the latents and append
    # null-class (1000) labels for the unconditional half of the batch.
    latent = torch.cat([latent, latent], 0)
    labels_null = torch.tensor([1000] * args.batch_size, device=device)
    labels = torch.cat([labels, labels_null], 0)

    # Sample images, keep the conditional half, and decode with the VAE:
    samples = diffusion.run(latent, labels)
    samples, _ = samples.chunk(2, dim=0)
    samples = vae_decode(samples / 0.18215, args.vae_decoder_engine)
    save_image(samples,
               "sample.png",
               nrow=4,
               normalize=True,
               value_range=(-1, 1))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--vae_decoder_engine',
                        type=str,
                        default='vae_decoder/plan/visual_encoder_fp16.plan',
                        help='Path to the serialized VAE decoder engine.')
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--image-size",
                        type=int,
                        choices=[256, 512],
                        default=512)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--num-sampling-steps", type=int, default=250)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--tllm_model_dir",
                        type=str,
                        default='./engine_outputs/')
    parser.add_argument("--gpus_per_node", type=int, default=8)
    parser.add_argument('--log_level', type=str, default='info')
    # `type=bool` would parse any non-empty string (even "False") as True,
    # so expose debug mode as a store_true flag instead.
    parser.add_argument("--debug_mode", action='store_true')
    args = parser.parse_args()
    main(args)
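
# Example invocation (paths are illustrative; the DiT and VAE engines must be
# built beforehand with the accompanying build scripts):
#   python sample.py --tllm_model_dir ./engine_outputs/ \
#       --vae_decoder_engine vae_decoder/plan/visual_encoder_fp16.plan \
#       --image-size 512 --num-sampling-steps 250
# Engines built with tensor parallelism are launched with mpirun, e.g.:
#   mpirun -n 2 python sample.py --tllm_model_dir ./engine_outputs/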