mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-28 14:44:24 +08:00
117 lines
4.6 KiB
Python
117 lines
4.6 KiB
Python
from typing import Dict, Tuple
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def ceil_div(x: int, y: int) -> int:
    """Return ``x / y`` rounded up to the next integer (intended for ``y > 0``).

    Used to count how many fixed-size blocks are needed to cover ``x`` items.
    """
    padded_numerator = x + y - 1
    return padded_numerator // y
|
|
|
|
|
|
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
assert x.dim() == 2 and x.size(1) % 128 == 0
|
|
m, n = x.shape
|
|
x_view = x.view(m, -1, 128)
|
|
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
|
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
|
|
torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
|
|
|
|
|
|
def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2-D tensor to FP8 (e4m3) with one scale per 128x128 tile.

    The input is zero-padded up to a multiple of 128 in both dimensions.
    Padding zeros never raise a tile's absolute maximum, so partial edge
    tiles are scaled by their real entries only.

    Args:
        x: 2-D tensor ``(m, n)``.

    Returns:
        ``(x_fp8, scales)``:
        - ``x_fp8``: contiguous, shape ``(m, n)``, dtype ``torch.float8_e4m3fn``.
        - ``scales``: float32, shape ``(ceil(m / 128), ceil(n / 128))``.
          Tile ``(i, j)`` dequantizes as ``x_fp8[tile].float() * scales[i, j]``.
    """
    assert x.dim() == 2
    m, n = x.shape
    m_blocks, n_blocks = ceil_div(m, 128), ceil_div(n, 128)
    x_padded = torch.zeros((m_blocks * 128, n_blocks * 128),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    # Explicit block counts instead of -1: view() cannot infer a -1 dimension
    # for a zero-element tensor (m == 0 or n == 0).
    # Element [a, b, c, d] is x_padded[a * 128 + b, c * 128 + d].
    x_view = x_padded.view(m_blocks, 128, n_blocks, 128)
    # Per-tile absolute maximum, shape (m_blocks, 1, n_blocks, 1); the clamp
    # avoids division by zero for all-zero tiles.
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    # 448.0 is the largest finite float8_e4m3fn value: map each tile's amax onto it.
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    x_fp8 = x_scaled.view_as(x_padded)[:m, :n].contiguous()
    scales = (x_amax / 448.0).view(m_blocks, n_blocks)
    return x_fp8, scales
|
|
|
|
|
|
def calc_diff(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the similarity-based difference ``1 - 2<x, y> / (|x|^2 + |y|^2)``.

    Computed in float64. The result is 0 when ``x == y`` and grows as the
    tensors diverge, up to 2 when ``y == -x``.

    If both inputs are all-zero (or empty), the denominator is 0. The tensors
    are then identical, so 0 is returned instead of NaN.

    Returns:
        A 0-dim float64 tensor.
    """
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        # 0 / 0 would yield NaN, and NaN fails every "diff < tol" check.
        return torch.zeros_like(denominator)
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
|
|
|
|
|
|
def reference_moe_torch(x: torch.Tensor, selected_experts: torch.Tensor,
                        final_scales: torch.Tensor, num_experts: int,
                        weights: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Unquantized reference for a SwiGLU mixture-of-experts layer.

    For every (row, slot) pair routed to expert ``e``, computes
    ``(silu(x_row @ W1_e^T) * (x_row @ W3_e^T)) @ W2_e^T``. The result is
    scaled by the matching ``final_scales`` entry and added into that row of
    the output.

    Args:
        x: input activations; rows are indexed by the first index of
            ``selected_experts`` (presumably ``(tokens, hidden)`` — confirm).
        selected_experts: expert id per (row, slot), indexed ``[row, slot]``.
        final_scales: routing weight per (row, slot), same indexing.
        num_experts: experts ``0 .. num_experts - 1`` are visited.
        weights: maps ``"{e}.w1.weight"``, ``"{e}.w2.weight"`` and
            ``"{e}.w3.weight"`` to expert ``e``'s matrices.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    accumulated = torch.zeros_like(x)

    # Visit every expert in turn; a plain loop keeps this easy to audit.
    for expert in range(num_experts):
        # Single-argument torch.where is equivalent to nonzero(as_tuple=True).
        token_ids, slot_ids = (selected_experts == expert).nonzero(
            as_tuple=True)
        gate_w = weights[f"{expert}.w1.weight"]
        down_w = weights[f"{expert}.w2.weight"]
        up_w = weights[f"{expert}.w3.weight"]

        tokens = x[token_ids]
        hidden = F.silu(tokens @ gate_w.t()) * (tokens @ up_w.t())
        routed = hidden @ down_w.t()

        accumulated[token_ids] += final_scales[token_ids, slot_ids,
                                               None] * routed

    return accumulated.view_as(x)
|
|
|
|
|
|
def _block_scale_gemm(mat_a: torch.Tensor,
                      mat_scale_a: torch.Tensor,
                      mat_b: torch.Tensor,
                      mat_scale_b: torch.Tensor,
                      block_size: int = 128) -> torch.Tensor:
    """Reference block-scaled GEMM computing ``A @ B^T`` from quantized operands.

    Args:
        mat_a: ``(m, k)`` quantized left operand (one scale per row per K-tile).
        mat_scale_a: ``(m, ceil(k / block_size))`` scales for ``mat_a``.
        mat_b: ``(n, k)`` quantized right operand (one scale per square tile).
        mat_scale_b: ``(ceil(n / block_size), ceil(k / block_size))`` scales
            for ``mat_b``.
        block_size: K-tile width and ``mat_b`` tile height.

    Returns:
        ``(m, n)`` bfloat16 tensor on ``mat_a``'s device.

    Each K-tile's partial product is computed in float32, multiplied by
    ``scale_a[row, tile] * scale_b[col // block_size, tile]``, and accumulated
    in float32 tile by tile. Whole output tiles are handled with one matmul
    instead of per-element Python loops.
    """
    shape_m, shape_k = mat_a.shape
    shape_n = mat_b.shape[0]
    device = mat_a.device
    scale_a = mat_scale_a.to(device)
    # Expand B's per-block-row scales to one row per output column n,
    # i.e. row n holds mat_scale_b[n // block_size].
    scale_b_per_col = mat_scale_b.to(device).repeat_interleave(
        block_size, dim=0)[:shape_n]
    result = torch.zeros((shape_m, shape_n), dtype=torch.float32, device=device)
    for tile, k in enumerate(range(0, shape_k, block_size)):
        tile_a = mat_a[:, k:k + block_size].float()
        tile_b = mat_b[:, k:k + block_size].to(device).float()
        # (m, 1) * (1, n) -> per-output-element scale for this K-tile.
        scale = scale_a[:, tile, None] * scale_b_per_col[None, :, tile]
        result += scale * (tile_a @ tile_b.t())
    return result.bfloat16()


def reference_block_scale_moe_torch(
        x: torch.Tensor, selected_experts: torch.Tensor,
        final_scales: torch.Tensor, num_experts: int,
        weights: Dict[str, torch.Tensor]) -> torch.Tensor:
    """FP8 block-scaled reference for a SwiGLU mixture-of-experts layer.

    Each expert's w1/w2/w3 weights are quantized per 128x128 tile with
    ``per_block_cast_to_fp8``. Routed activations are quantized per row
    group with ``per_token_cast_to_fp8``. The GEMMs run through
    ``_block_scale_gemm``, which returns bfloat16. The intermediate SwiGLU
    activation is re-quantized before the down projection (w2).

    Args:
        x: input activations; rows are indexed by the first index of
            ``selected_experts`` (presumably ``(tokens, hidden)`` with
            hidden a multiple of 128 — confirm).
        selected_experts: expert id per (row, slot), indexed ``[row, slot]``.
        final_scales: routing weight per (row, slot), same indexing.
        num_experts: experts ``0 .. num_experts - 1`` are visited.
        weights: maps ``"{e}.w1.weight"``, ``"{e}.w2.weight"`` and
            ``"{e}.w3.weight"`` to expert ``e``'s unquantized matrices.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    results = torch.zeros_like(x)

    # naive looping over experts
    for expert_id in range(num_experts):
        batch_idx, nth_expert = torch.where(selected_experts == expert_id)
        if batch_idx.numel() == 0:
            # No rows routed to this expert. Skipping avoids quantizing a
            # zero-row activation tensor and needlessly quantizing the
            # expert's weights.
            continue

        w1_fp8, w1_scale = per_block_cast_to_fp8(
            weights[f"{expert_id}.w1.weight"])
        w2_fp8, w2_scale = per_block_cast_to_fp8(
            weights[f"{expert_id}.w2.weight"])
        w3_fp8, w3_scale = per_block_cast_to_fp8(
            weights[f"{expert_id}.w3.weight"])

        x_fp8, x_scale = per_token_cast_to_fp8(x[batch_idx])

        # gemm1: SiLU is applied to the w1 projection and multiplied by the
        # w3 projection.
        gate_output = F.silu(_block_scale_gemm(x_fp8, x_scale, w1_fp8,
                                               w1_scale))
        up_output = _block_scale_gemm(x_fp8, x_scale, w3_fp8, w3_scale)
        act_output = gate_output * up_output

        # gemm2: re-quantize the activation before the w2 projection.
        act_fp8, act_scale = per_token_cast_to_fp8(act_output)
        output = _block_scale_gemm(act_fp8, act_scale, w2_fp8, w2_scale)

        results[batch_idx] += final_scales[batch_idx, nth_expert, None] * output

    return results.view_as(x)
|