TensorRT-LLMs/examples/models/contrib/dit/diffusion.py
bhsueh_NV 322ac565fc
chore: clean some ci of qa test (#3083)
* move some models to examples/models/contrib

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* update the document

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* remove arctic, blip2, cogvlm, dbrx from qa test list

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* remove tests of dit, mmdit and stdit from qa test

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* remove grok, jais, sdxl, skywork, smaug from qa test list

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* re-organize the glm examples

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* fix issues after running pre-commit

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* fix some typo in glm_4_9b readme

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

* fix bug

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>

---------

Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
2025-03-31 14:30:41 +08:00

117 lines
5.1 KiB
Python

import numpy as np
import torch
from tqdm.auto import tqdm
def space_timesteps(num_timesteps, timestep_respacing):
    """Pick ``timestep_respacing`` evenly spaced indices out of ``range(num_timesteps)``.

    The selection always contains index 0 and, when more than one step is
    requested, index ``num_timesteps - 1``.

    Args:
        num_timesteps: length of the original diffusion schedule.
        timestep_respacing: number of steps to keep (1 <= value <= num_timesteps).

    Returns:
        A set of integer timestep indices.

    Raises:
        ValueError: if ``timestep_respacing`` is < 1 or > ``num_timesteps``.
    """
    if timestep_respacing < 1 or num_timesteps < timestep_respacing:
        raise ValueError(
            f"cannot divide section of {num_timesteps} steps into {timestep_respacing}"
        )
    if timestep_respacing == 1:
        # Single-step schedule: the stride formula below would divide by
        # zero, so handle it explicitly.
        return {0}
    # Fractional stride so the first index is 0 and the last is
    # num_timesteps - 1.
    frac_stride = (num_timesteps - 1) / (timestep_respacing - 1)
    cur_idx = 0.0
    taken_steps = []
    for _ in range(timestep_respacing):
        taken_steps.append(round(cur_idx))
        cur_idx += frac_stride
    return set(taken_steps)
def _extract_into_tensor(arr, timesteps, broadcast_shape):
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res + torch.zeros(broadcast_shape, device=timesteps.device)
class DiTDiffusionPipeline:
    """Gaussian-diffusion sampler for a DiT model using a respaced
    (shortened) timestep schedule and a learned per-pixel variance, as in
    "Improved DDPM"-style models."""

    def __init__(self, dit_model, timestep_respacing, diffusion_steps=1000):
        """Precompute the respaced schedule and all posterior coefficients.

        Args:
            dit_model: callable ``(sample, t, labels) -> tensor`` returning
                shape ``(B, 2*C, ...)`` — first C channels the noise
                prediction, last C the raw variance interpolation values.
            timestep_respacing: number of sampling steps to run.
            diffusion_steps: length of the original training schedule.
        """
        self.dit_model = dit_model
        self.timesteps = space_timesteps(diffusion_steps, timestep_respacing)
        # Maps each respaced step index back to its original timestep;
        # populated by _setup_betas().
        self.timestep_map = []
        self.original_num_steps = diffusion_steps

        betas = self._setup_betas()
        self.betas = betas
        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps, )

        # Coefficients for recovering x_0 from x_t and the predicted noise.
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod -
                                                   1)

        # Variance and mean coefficients of the posterior
        # q(x_{t-1} | x_t, x_0).
        self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) /
                                   (1.0 - self.alphas_cumprod))
        # posterior_variance[0] is 0, so the log is "clipped" by reusing
        # the value at index 1 for the first entry.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1],
                      self.posterior_variance[1:])) if len(
                          self.posterior_variance) > 1 else np.array([])
        self.posterior_mean_coef1 = (betas * np.sqrt(self.alphas_cumprod_prev) /
                                     (1.0 - self.alphas_cumprod))
        self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) *
                                     np.sqrt(alphas) /
                                     (1.0 - self.alphas_cumprod))
        # Hoisted loop invariant for run(): log(betas) was previously
        # recomputed on every sampling step.
        self._log_betas = np.log(self.betas)

    def _setup_betas(self):
        """Build the linear beta schedule restricted to ``self.timesteps``.

        Also fills ``self.timestep_map`` with the original timestep index of
        every retained step. Returns the respaced betas as a 1-D float64
        array.
        """
        scale = 1000 / self.original_num_steps
        betas = np.linspace(scale * 0.0001,
                            scale * 0.02,
                            self.original_num_steps,
                            dtype=np.float64)
        last_alpha_cumprod = 1.0
        new_betas = []
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        for i, alpha_cumprod in enumerate(alphas_cumprod):
            if i in self.timesteps:
                # Beta that yields the same cumulative alpha when the
                # skipped intermediate steps are collapsed into one.
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        betas = np.array(new_betas)
        return betas

    def run(self, sample, labels):
        """Run the full reverse-diffusion sampling loop.

        Args:
            sample: initial noise tensor of shape ``(B, C, ...)``.
            labels: conditioning input forwarded to the DiT model.

        Returns:
            The denoised sample, same shape as ``sample``.
        """
        # Loop-invariant lookup table (respaced index -> original
        # timestep); previously rebuilt on every iteration.
        map_tensor = torch.tensor(self.timestep_map,
                                  device=sample.device,
                                  dtype=torch.long)
        indices = list(range(self.num_timesteps))[::-1]
        for i in tqdm(indices):
            t = torch.tensor([i] * sample.shape[0], device=sample.device)
            B, C = sample.shape[:2]
            assert t.shape == (B, )
            model_output = self.dit_model(sample, map_tensor[t], labels)
            assert model_output.shape == (B, C * 2, *sample.shape[2:])
            # First C channels: predicted noise; last C: raw variance
            # interpolation coefficients in [-1, 1].
            model_output, model_var_values = torch.split(model_output, C, dim=1)
            # Interpolate log-variance between the clipped posterior
            # minimum and log(beta_t).
            min_log = _extract_into_tensor(self.posterior_log_variance_clipped,
                                           t, sample.shape)
            max_log = _extract_into_tensor(self._log_betas, t, sample.shape)
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log
            # Recover x_0 from x_t and the predicted noise.
            pred_xstart = (
                _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
                                     sample.shape) * sample -
                _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t,
                                     sample.shape) * model_output)
            # Posterior mean of q(x_{t-1} | x_t, x_0).
            model_mean = (_extract_into_tensor(self.posterior_mean_coef1, t,
                                               sample.shape) * pred_xstart +
                          _extract_into_tensor(self.posterior_mean_coef2, t,
                                               sample.shape) * sample)
            assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == sample.shape
            noise = torch.randn_like(sample)
            # Mask out the noise at the final step (t == 0).
            nonzero_mask = ((t != 0).float().view(
                -1, *([1] * (len(sample.shape) - 1))))
            sample = model_mean + nonzero_mask * torch.exp(
                0.5 * model_log_variance) * noise
        return sample