mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
* move some models to examples/models/contrib Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com> * update the document Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com> * remove arctic, blip2, cogvlm, dbrx from qa test list Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com> * remove tests of dit, mmdit and stdit from qa test Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com> * remove grok, jais, sdxl, skywork, smaug from qa test list Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com> * re-organize the glm examples Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com> * fix issues after running pre-commit Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com> * fix some typo in glm_4_9b readme Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com> * fix bug Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com> --------- Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
117 lines
5.1 KiB
Python
117 lines
5.1 KiB
Python
import numpy as np
|
|
import torch
|
|
from tqdm.auto import tqdm
|
|
|
|
|
|
def space_timesteps(num_timesteps, timestep_respacing):
    """Select an evenly spaced subset of timesteps from a diffusion schedule.

    Args:
        num_timesteps: number of steps in the original schedule.
        timestep_respacing: number of steps to keep (must not exceed
            ``num_timesteps``).

    Returns:
        A set of integer timestep indices in ``[0, num_timesteps - 1]``,
        always containing both endpoints (0 and ``num_timesteps - 1``) when
        ``timestep_respacing > 1``.

    Raises:
        ValueError: if ``timestep_respacing`` exceeds ``num_timesteps``.
    """
    if num_timesteps < timestep_respacing:
        raise ValueError(
            f"cannot divide section of {num_timesteps} steps into {timestep_respacing}"
        )
    # Edge case: a single step would make the stride 0/0; just take step 0.
    # (The original code divided by zero here.)
    if timestep_respacing == 1:
        return {0}
    # Fractional stride so the first and last original steps are both kept.
    frac_stride = (num_timesteps - 1) / (timestep_respacing - 1)
    cur_idx = 0.0
    taken_steps = []
    for _ in range(timestep_respacing):
        taken_steps.append(round(cur_idx))
        cur_idx += frac_stride
    return set(taken_steps)
|
|
|
|
|
|
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
|
res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
|
while len(res.shape) < len(broadcast_shape):
|
|
res = res[..., None]
|
|
return res + torch.zeros(broadcast_shape, device=timesteps.device)
|
|
|
|
|
|
class DiTDiffusionPipeline:
    """Ancestral (DDPM-style) sampler for a DiT model with respaced timesteps.

    Precomputes the Gaussian-diffusion schedule (betas, cumulative alpha
    products, and posterior coefficients) over a respaced subset of the
    original ``diffusion_steps`` schedule, then iteratively denoises a sample
    in :meth:`run`.
    """

    def __init__(self, dit_model, timestep_respacing, diffusion_steps=1000):
        """
        Args:
            dit_model: callable ``(x, t, labels) -> tensor`` whose output has
                twice the channels of ``x`` (predicted noise and variance
                interpolation values concatenated on dim 1).
            timestep_respacing: number of sampling steps actually taken.
            diffusion_steps: length of the original training schedule.
        """
        self.dit_model = dit_model
        # Subset of original timesteps that sampling will visit.
        self.timesteps = space_timesteps(diffusion_steps, timestep_respacing)
        # Filled by _setup_betas: respaced index -> original timestep.
        self.timestep_map = []
        self.original_num_steps = diffusion_steps

        betas = self._setup_betas()
        self.betas = betas
        self.num_timesteps = int(betas.shape[0])

        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
        self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
        assert self.alphas_cumprod_prev.shape == (self.num_timesteps, )

        # Coefficients to recover x_0 from x_t and the predicted noise.
        self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod -
                                                   1)

        # Variance of the posterior q(x_{t-1} | x_t, x_0).
        self.posterior_variance = (betas * (1.0 - self.alphas_cumprod_prev) /
                                   (1.0 - self.alphas_cumprod))
        # Log-variance, with the t=0 entry clipped (the true value is 0,
        # whose log would be -inf), so reuse the t=1 entry instead.
        self.posterior_log_variance_clipped = np.log(
            np.append(self.posterior_variance[1],
                      self.posterior_variance[1:])) if len(
                          self.posterior_variance) > 1 else np.array([])

        # Posterior mean expressed as coef1 * x_0 + coef2 * x_t.
        self.posterior_mean_coef1 = (betas * np.sqrt(self.alphas_cumprod_prev) /
                                     (1.0 - self.alphas_cumprod))
        self.posterior_mean_coef2 = ((1.0 - self.alphas_cumprod_prev) *
                                     np.sqrt(alphas) /
                                     (1.0 - self.alphas_cumprod))

    def _setup_betas(self):
        """Build the respaced beta schedule and populate ``timestep_map``."""
        # Linear schedule, scaled so its endpoints match a 1000-step schedule.
        scale = 1000 / self.original_num_steps
        betas = np.linspace(scale * 0.0001,
                            scale * 0.02,
                            self.original_num_steps,
                            dtype=np.float64)
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)

        # Re-derive betas so that the cumulative alpha product at each kept
        # step equals the original schedule's value at that step.
        last_alpha_cumprod = 1.0
        new_betas = []
        for i, alpha_cumprod in enumerate(alphas_cumprod):
            if i in self.timesteps:
                new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
                last_alpha_cumprod = alpha_cumprod
                self.timestep_map.append(i)
        return np.array(new_betas)

    def run(self, sample, labels):
        """Iteratively denoise ``sample`` through the respaced schedule.

        Args:
            sample: noise tensor of shape ``(B, C, ...)``; its shape is
                preserved across iterations.
            labels: conditioning input forwarded to the model unchanged.

        Returns:
            The fully denoised sample, same shape as the input.
        """
        B, C = sample.shape[:2]
        # Loop-invariant work hoisted out of the sampling loop: the
        # respaced->original timestep lookup table and log(betas) were
        # previously rebuilt on every iteration.
        map_tensor = torch.tensor(self.timestep_map,
                                  device=sample.device,
                                  dtype=torch.long)
        log_betas = np.log(self.betas)

        # Walk the respaced timesteps from most noisy to least noisy.
        indices = list(range(self.num_timesteps))[::-1]
        for i in tqdm(indices):
            t = torch.tensor([i] * B, device=sample.device)
            assert t.shape == (B, )

            # Model predicts noise and variance-interpolation values,
            # concatenated along the channel dimension.
            model_output = self.dit_model(sample, map_tensor[t], labels)
            assert model_output.shape == (B, C * 2, *sample.shape[2:])
            model_output, model_var_values = torch.split(model_output, C, dim=1)

            # Learned variance: interpolate in log space between the clipped
            # posterior lower bound and the beta_t upper bound.
            min_log = _extract_into_tensor(self.posterior_log_variance_clipped,
                                           t, sample.shape)
            max_log = _extract_into_tensor(log_betas, t, sample.shape)
            frac = (model_var_values + 1) / 2
            model_log_variance = frac * max_log + (1 - frac) * min_log

            # Predict x_0 from x_t and the predicted noise.
            pred_xstart = (
                _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t,
                                     sample.shape) * sample -
                _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t,
                                     sample.shape) * model_output)

            # Posterior mean of q(x_{t-1} | x_t, x_0).
            model_mean = (_extract_into_tensor(self.posterior_mean_coef1, t,
                                               sample.shape) * pred_xstart +
                          _extract_into_tensor(self.posterior_mean_coef2, t,
                                               sample.shape) * sample)
            assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == sample.shape

            # Add noise everywhere except on the final (t == 0) step.
            noise = torch.randn_like(sample)
            nonzero_mask = ((t != 0).float().view(
                -1, *([1] * (len(sample.shape) - 1))))
            sample = model_mean + nonzero_mask * torch.exp(
                0.5 * model_log_variance) * noise
        return sample
|