TensorRT-LLMs/tests/unittest/_torch/helpers.py
QI JUN 524754b6fd
[TRTLLM-8521][chore] remove circular dependency between model engine and cuda graph runner (#7572)
Signed-off-by: junq <22017000+QiJune@users.noreply.github.com>
2025-11-11 10:13:45 -08:00

192 lines
7.2 KiB
Python

from typing import Dict, Tuple
import torch
import torch.nn.functional as F
from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import (
CUDAGraphRunner, CUDAGraphRunnerConfig)
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm.mapping import Mapping
def ceil_div(x: int, y: int) -> int:
    """Return the ceiling of ``x / y`` using integer arithmetic only.

    Negating before and after the floor division turns ``//`` (which rounds
    toward negative infinity) into a ceiling for every sign combination.
    The more common ``(x + y - 1) // y`` form is only a ceiling when ``y``
    is positive; for positive ``y`` both forms give the same result.

    Args:
        x: Dividend.
        y: Divisor; must be non-zero.

    Returns:
        The smallest integer ``q`` such that ``q >= x / y``.

    Raises:
        ZeroDivisionError: If ``y`` is zero.
    """
    return -(-x // y)
def align(x: int, y: int) -> int:
    """Round ``x`` up to a multiple of ``y``.

    The number of ``y``-sized chunks needed to cover ``x`` is obtained from
    ``ceil_div`` and then scaled back by ``y``.
    """
    num_chunks = ceil_div(x, y)
    return num_chunks * y
def ceil_to_ue8m0(x: torch.Tensor):
    """Round the magnitude of every element up to a power of two.

    Computes ``2 ** ceil(log2(|x|))`` element-wise, so the sign of the input
    is discarded and values that are already powers of two are unchanged.
    """
    exponent = torch.ceil(torch.log2(torch.abs(x)))
    return torch.pow(2.0, exponent)
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to ``float8_e4m3fn`` with one scale per 128x128 block.

    The input is zero-padded up to multiples of 128 in both dimensions so it
    can be viewed as a grid of 128x128 blocks. Each block is rescaled so that
    its largest magnitude maps to 448.0 before the cast, and the padding is
    cropped from the quantized result.

    Args:
        x: 2-D tensor of any size.

    Returns:
        A pair ``(x_fp8, scales)``: ``x_fp8`` is a contiguous
        ``float8_e4m3fn`` tensor with the shape of ``x``; ``scales`` has
        shape ``(ceil(m / 128), ceil(n / 128))``.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    padded_shape = (ceil_div(rows, 128) * 128, ceil_div(cols, 128) * 128)
    padded = torch.zeros(padded_shape, dtype=x.dtype, device=x.device)
    padded[:rows, :cols] = x
    # Layout: (row blocks, 128 rows, column blocks, 128 columns).
    blocks = padded.view(-1, 128, padded.size(1) // 128, 128)
    # Clamp so that an all-zero block does not cause a division by zero.
    block_amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    quantized = (blocks * (448.0 / block_amax)).to(torch.float8_e4m3fn)
    quantized = quantized.view_as(padded)[:rows, :cols].contiguous()
    scales = (block_amax / 448.0).view(blocks.size(0), blocks.size(2))
    return quantized, scales
def per_token_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize per 1x128 group to ``float8_e4m3fn`` using power-of-two scales.

    Same grouping as ``per_token_cast_to_fp8``, but each group's scale
    ``amax / 448`` is first rounded up to a power of two by
    ``ceil_to_ue8m0``; the data is divided by that rounded scale.

    Args:
        x: 2-D tensor whose second dimension is a multiple of 128.

    Returns:
        A pair ``(x_fp8, scales)``: ``x_fp8`` has the shape of ``x``;
        ``scales`` has shape ``(m, n // 128)`` and contains powers of two.
    """
    assert x.dim() == 2
    assert x.size(1) % 128 == 0
    num_rows, num_cols = x.shape
    groups = x.view(num_rows, -1, 128)
    # Clamp so that an all-zero group still yields a finite, non-zero scale.
    group_amax = groups.abs().float().amax(dim=2).view(num_rows, -1).clamp(1e-4)
    scales = ceil_to_ue8m0(group_amax / 448.0)
    rescaled = groups * (1.0 / scales.unsqueeze(2))
    quantized = rescaled.to(torch.float8_e4m3fn).view(num_rows, num_cols)
    return quantized, scales
def per_block_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize per 128x128 block to ``float8_e4m3fn`` using power-of-two scales.

    Same blocking as ``per_block_cast_to_fp8``, but each block's scale
    ``amax / 448`` is first rounded up to a power of two by
    ``ceil_to_ue8m0``; the data is divided by that rounded scale.

    Args:
        x: 2-D tensor of any size.

    Returns:
        A pair ``(x_fp8, scales)``: ``x_fp8`` is a contiguous
        ``float8_e4m3fn`` tensor with the shape of ``x``; ``scales`` has
        shape ``(ceil(m / 128), ceil(n / 128))`` and contains powers of two.
    """
    assert x.dim() == 2
    rows, cols = x.shape
    padded = torch.zeros((align(rows, 128), align(cols, 128)),
                         dtype=x.dtype,
                         device=x.device)
    padded[:rows, :cols] = x
    # Layout: (row blocks, 128 rows, column blocks, 128 columns).
    blocks = padded.view(-1, 128, padded.size(1) // 128, 128)
    # Clamp so that an all-zero block still yields a finite, non-zero scale.
    block_amax = blocks.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    scales = ceil_to_ue8m0(block_amax / 448.0)
    quantized = (blocks * (1.0 / scales)).to(torch.float8_e4m3fn)
    quantized = quantized.view_as(padded)[:rows, :cols].contiguous()
    return quantized, scales.view(blocks.size(0), blocks.size(2))
def calc_diff(x, y):
    """Return ``1 - 2 * sum(x * y) / sum(x * x + y * y)`` in float64.

    The result is 0 when the two tensors are identical and grows as they
    diverge (2 when ``y == -x``). Both inputs are promoted to double
    precision before any arithmetic.
    """
    a = x.double()
    b = y.double()
    energy = torch.sum(a * a + b * b)
    similarity = 2 * torch.sum(a * b) / energy
    return 1 - similarity
def calc_woq_tolerence(x: torch.Tensor, weight_dtype: torch.dtype):
    """Return the absolute tolerance for checking weight-only-quantized results.

    The tolerance is ``max(|x|) / 2 ** (bits - 1) * 1.5`` where ``bits`` is
    8 for ``torch.int8`` and 4 for ``torch.quint4x2``.

    Args:
        x: Reference tensor whose largest magnitude sets the tolerance.
        weight_dtype: Quantized weight dtype; ``torch.int8`` or
            ``torch.quint4x2``.

    Returns:
        The tolerance as a Python float.

    Raises:
        ValueError: If ``weight_dtype`` is not one of the supported dtypes.
    """
    # align with woq_assert_near_eq function in tests/unittest/trt/quantization/_utils.py
    if weight_dtype == torch.int8:
        bits_in_type = 8
    elif weight_dtype == torch.quint4x2:
        bits_in_type = 4
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unsupported weight_dtype {weight_dtype}; "
            "expected torch.int8 or torch.quint4x2")
    quant_range_scale = 1.0 / float(1 << (bits_in_type - 1))
    max_val = torch.max(abs(x)).item()
    atol = (max_val * quant_range_scale) * 1.5  # allow for rounding
    return atol
def reference_moe_torch(x: torch.Tensor, selected_experts: torch.Tensor,
                        final_scales: torch.Tensor, num_experts: int,
                        weights: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Naive reference for a gated mixture-of-experts layer.

    For every expert, the rows of ``x`` routed to it are pushed through
    ``(silu(x @ w1.T) * (x @ w3.T)) @ w2.T`` and accumulated into the output,
    weighted by the matching entry of ``final_scales``.

    Args:
        x: Input rows; indexed along dim 0 by token.
        selected_experts: Expert ids per token; indexed as
            ``[token, slot]`` together with ``final_scales``.
        final_scales: Routing weights with the same indexing as
            ``selected_experts``.
        num_experts: Number of experts to iterate over.
        weights: Maps ``"{expert_id}.w1.weight"``, ``"{expert_id}.w2.weight"``
            and ``"{expert_id}.w3.weight"`` to the expert's matrices.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    # Accumulator with the same shape, dtype and device as the input.
    accumulated = torch.zeros_like(x)
    for expert_id in range(num_experts):
        # Rows routed to this expert and the slot in which it was chosen.
        token_idx, slot_idx = torch.where(selected_experts == expert_id)
        gate_weight = weights[f"{expert_id}.w1.weight"]
        down_weight = weights[f"{expert_id}.w2.weight"]
        up_weight = weights[f"{expert_id}.w3.weight"]
        tokens = x[token_idx]
        gated = F.silu(tokens @ gate_weight.t()) * (tokens @ up_weight.t())
        expert_out = gated @ down_weight.t()
        # The trailing None broadcasts each routing weight across the row.
        accumulated[token_idx] += (final_scales[token_idx, slot_idx, None] *
                                   expert_out)
    return accumulated.view_as(x)
def reference_block_scale_moe_torch(
        x: torch.Tensor, selected_experts: torch.Tensor,
        final_scales: torch.Tensor, num_experts: int,
        weights: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Naive reference for a gated MoE layer with block-scaled FP8 GEMMs.

    For every expert, the routed rows of ``x`` are quantized per 1x128 group
    and the expert weights per 128x128 block. Both GEMM stages are then
    evaluated tile by tile in float32, applying the product of the two
    scales to each 128-wide partial dot product. Outputs are weighted by
    ``final_scales`` and accumulated.

    Args:
        x: Input rows; the second dimension must be a multiple of 128
            (required by ``per_token_cast_to_fp8``).
        selected_experts: Expert ids per token, indexed as ``[token, slot]``.
        final_scales: Routing weights with the same indexing as
            ``selected_experts``.
        num_experts: Number of experts to iterate over.
        weights: Maps ``"{expert_id}.w1.weight"``, ``"{expert_id}.w2.weight"``
            and ``"{expert_id}.w3.weight"`` to the expert's matrices.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """

    def block_scale_gemm(mat_a: torch.Tensor, mat_scale_a: torch.Tensor,
                         mat_b: torch.Tensor, mat_scale_b: torch.Tensor):
        """Compute ``mat_a @ mat_b.T`` with per-tile scaling; returns bf16."""
        shape_m, shape_k = mat_a.shape
        shape_n = mat_b.shape[0]
        # Accumulate in float32 on the same device as the operands.
        result = torch.zeros((shape_m, shape_n),
                             dtype=torch.float32,
                             device=mat_a.device)
        for m in range(shape_m):
            for n in range(shape_n):
                for k in range(0, shape_k, 128):
                    # mat_a has one scale per (row, k-tile); mat_b has one
                    # scale per (128-row block, k-tile).
                    scale_factor = (mat_scale_a[m, k // 128] *
                                    mat_scale_b[n // 128, k // 128])
                    tile_a = mat_a[m, k:k + 128]
                    tile_b = mat_b[n, k:k + 128]
                    tile_d = torch.dot(tile_a.float(), tile_b.float())
                    result[m, n] += scale_factor * tile_d
        return result.bfloat16()

    results = torch.zeros_like(x)
    # naive looping over experts
    for expert_id in range(num_experts):
        batch_idx, nth_expert = torch.where(selected_experts == expert_id)
        if batch_idx.numel() == 0:
            # No token was routed here. Quantizing an empty selection would
            # fail (view with -1 is ambiguous for zero elements), and the
            # expert contributes nothing to the output anyway.
            continue
        w1_fp8, w1_scale = per_block_cast_to_fp8(
            weights[f"{expert_id}.w1.weight"])
        w2_fp8, w2_scale = per_block_cast_to_fp8(
            weights[f"{expert_id}.w2.weight"])
        w3_fp8, w3_scale = per_block_cast_to_fp8(
            weights[f"{expert_id}.w3.weight"])
        x_fp8, x_scale = per_token_cast_to_fp8(x[batch_idx])

        # gemm1: w1 feeds the SiLU gate, w3 is the linear branch.
        gate_output = F.silu(block_scale_gemm(x_fp8, x_scale, w1_fp8, w1_scale))
        up_output = block_scale_gemm(x_fp8, x_scale, w3_fp8, w3_scale)
        act_output = gate_output * up_output

        # gemm2: re-quantize the activation before the w2 projection.
        act_fp8, act_scale = per_token_cast_to_fp8(act_output)
        output = block_scale_gemm(act_fp8, act_scale, w2_fp8, w2_scale)
        results[batch_idx] += final_scales[batch_idx, nth_expert, None] * output
    return results.view_as(x)
def create_mock_cuda_graph_runner(batch_size: int, use_mrope: bool = False):
    """Build a ``CUDAGraphRunner`` for tests, sized for a single batch size.

    Args:
        batch_size: Used as the only entry of ``cuda_graph_batch_sizes``, as
            ``max_cuda_graph_batch_size`` and as the config's ``batch_size``.
        use_mrope: Forwarded unchanged to the runner config.

    Returns:
        A ``CUDAGraphRunner`` built from the config assembled below.
    """
    runner_config = CUDAGraphRunnerConfig(
        # CUDA graph settings: enabled, no padding, one batch size.
        use_cuda_graph=True,
        cuda_graph_padding_enabled=False,
        cuda_graph_batch_sizes=[batch_size],
        max_cuda_graph_batch_size=batch_size,
        cuda_graph_mem_pool=None,
        # Request sizing.
        batch_size=batch_size,
        max_beam_width=1,
        max_num_tokens=1,
        use_mrope=use_mrope,
        # No speculative decoding / draft model.
        spec_config=None,
        original_max_draft_len=0,
        original_max_total_draft_tokens=0,
        is_draft_model=False,
        # Default mapping, no distributed communicator, no attention DP.
        enable_attention_dp=False,
        mapping=Mapping(),
        dist=None,
        kv_cache_manager_key=ResourceManagerType.KV_CACHE_MANAGER,
    )
    return CUDAGraphRunner(runner_config)