from typing import Dict, Tuple

import torch
import torch.nn.functional as F

from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import (
    CUDAGraphRunner, CUDAGraphRunnerConfig)
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm.mapping import Mapping


def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y


def align(x: int, y: int) -> int:
    return ceil_div(x, y) * y


def ceil_to_ue8m0(x: torch.Tensor):
    # Round a scale up to the nearest power of two (UE8M0 scale format).
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))


# Note: 448.0 below is the maximum finite magnitude of torch.float8_e4m3fn.
def per_token_cast_to_fp8(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(
        torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)


def per_block_cast_to_fp8(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros((ceil_div(m, 128) * 128, ceil_div(n, 128) * 128),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (
        x_amax / 448.0).view(x_view.size(0), x_view.size(2))


def per_token_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    sf = ceil_to_ue8m0(x_amax / 448.0)
    return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(
        m, n), sf


def per_block_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros((align(m, 128), align(n, 128)),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = ceil_to_ue8m0(x_amax / 448.0)
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2))


def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim


def calc_woq_tolerence(x: torch.Tensor, weight_dtype: torch.dtype):
    # align with woq_assert_near_eq function in tests/unittest/trt/quantization/_utils.py
    if weight_dtype == torch.int8:
        bits_in_type = 8
    elif weight_dtype == torch.quint4x2:
        bits_in_type = 4
    quant_range_scale = 1.0 / float(1 << (bits_in_type - 1))

    max_val = torch.max(abs(x)).item()
    atol = (max_val * quant_range_scale) * 1.5  # allow for rounding
    return atol
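
# Illustrative sketch only (not used by the tests): how the per-token FP8 cast
# pairs with `calc_diff`. The helper name `_example_fp8_roundtrip` and the 1e-2
# threshold are assumptions made for this example, not part of the original
# helpers.
def _example_fp8_roundtrip(device: str = "cuda") -> float:
    x = torch.randn(4, 256, dtype=torch.float32, device=device)
    x_fp8, x_sf = per_token_cast_to_fp8(x)
    # Dequantize: rescale each 128-wide group by its per-token scale factor.
    x_deq = (x_fp8.float().view(4, -1, 128) * x_sf.unsqueeze(2)).view(4, 256)
    diff = calc_diff(x_deq, x)
    assert diff < 1e-2  # only quantization noise, so the similarity gap is tiny
    return float(diff)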
def reference_moe_torch(x: torch.Tensor, selected_experts: torch.Tensor,
                        final_scales: torch.Tensor, num_experts: int,
                        weights: Dict[str, torch.Tensor],
                        apply_routing_on_input: bool = False) -> torch.Tensor:
    # cast back to the input dtype
    results = torch.zeros_like(x)

    # naive looping over experts
    for expert_id in range(num_experts):
        batch_idx, nth_expert = torch.where(selected_experts == expert_id)
        w1_weight = weights[f"{expert_id}.w1.weight"]
        w2_weight = weights[f"{expert_id}.w2.weight"]
        w3_weight = weights[f"{expert_id}.w3.weight"]
        expert_inputs = x[batch_idx]
        if apply_routing_on_input:
            expert_inputs = expert_inputs * final_scales[batch_idx, nth_expert,
                                                         None]
        output = (F.silu(expert_inputs @ w1_weight.t()) *
                  (expert_inputs @ w3_weight.t())) @ w2_weight.t()
        if not apply_routing_on_input:
            output = output * final_scales[batch_idx, nth_expert, None]
        results[batch_idx] += output

    return results.view_as(x)


def reference_bmm_moe_torch(x: torch.Tensor,
                            selected_experts: torch.Tensor,
                            final_scales: torch.Tensor,
                            w3_w1_stacked_weight: torch.Tensor,
                            w2_stacked_weight: torch.Tensor,
                            apply_routing_on_input: bool = True) -> torch.Tensor:
    """Reference for stacked MoE in TRT-LLM format.

    Args:
        x: (seq_len, hidden_size)
        selected_experts: (seq_len, topk)
        final_scales: (seq_len, topk)
        w3_w1_stacked_weight: (num_experts, 2*intermediate_size, hidden_size) - TRT-LLM format
        w2_stacked_weight: (num_experts, hidden_size, intermediate_size) - TRT-LLM format
    """
    num_experts = w3_w1_stacked_weight.shape[0]
    intermediate_size = w3_w1_stacked_weight.shape[1] // 2
    results = torch.zeros_like(x)

    # Loop over experts (matches reference_moe_torch pattern)
    for expert_id in range(num_experts):
        batch_idx, nth_expert = torch.where(selected_experts == expert_id)
        if len(batch_idx) == 0:
            continue

        # Get weights for this expert (TRT-LLM format)
        gate_up = w3_w1_stacked_weight[expert_id]  # (2*I, H)
        w3_weight = gate_up[:intermediate_size, :]  # (I, H)
        w1_weight = gate_up[intermediate_size:, :]  # (I, H)
        w2_weight = w2_stacked_weight[expert_id]  # (H, I)

        expert_inputs = x[batch_idx]
        if apply_routing_on_input:
            expert_inputs = expert_inputs * final_scales[batch_idx, nth_expert,
                                                         None]

        # Gated MLP computation - TRT-LLM format uses .t()
        output = (F.silu(expert_inputs @ w1_weight.t()) *
                  (expert_inputs @ w3_weight.t())) @ w2_weight.t()

        if not apply_routing_on_input:
            output = output * final_scales[batch_idx, nth_expert, None]
        results[batch_idx] += output

    return results.view_as(x)


def reference_block_scale_moe_torch(
        x: torch.Tensor, selected_experts: torch.Tensor,
        final_scales: torch.Tensor, num_experts: int,
        weights: Dict[str, torch.Tensor]) -> torch.Tensor:
    results = torch.zeros_like(x)

    # naive looping over experts
    for expert_id in range(num_experts):
        batch_idx, nth_expert = torch.where(selected_experts == expert_id)
        w1 = weights[f"{expert_id}.w1.weight"]
        w2 = weights[f"{expert_id}.w2.weight"]
        w3 = weights[f"{expert_id}.w3.weight"]
        w1_fp8, w1_scale = per_block_cast_to_fp8(w1)
        w2_fp8, w2_scale = per_block_cast_to_fp8(w2)
        w3_fp8, w3_scale = per_block_cast_to_fp8(w3)
        x_fp8, x_scale = per_token_cast_to_fp8(x[batch_idx])

        def block_scale_gemm(mat_a: torch.Tensor, mat_scale_a: torch.Tensor,
                             mat_b: torch.Tensor, mat_scale_b: torch.Tensor):
            shape_m, shape_k = mat_a.shape
            shape_n = mat_b.shape[0]
            result = torch.zeros((shape_m, shape_n),
                                 dtype=torch.float32).cuda()
            for m in range(shape_m):
                for n in range(shape_n):
                    for k in range(0, shape_k, 128):
                        scale_factor = mat_scale_a[m, k // 128] * mat_scale_b[
                            n // 128, k // 128]
                        tile_a = mat_a[m, k:k + 128]
                        tile_b = mat_b[n, k:k + 128]
                        tile_d = torch.dot(tile_a.float(), tile_b.float())
                        result[m, n] += scale_factor.cuda() * tile_d.cuda(
                        ).float()
            result_bf16 = result.bfloat16()
            return result_bf16

        # gemm1
        fc3_output = block_scale_gemm(x_fp8, x_scale, w1_fp8, w1_scale)
        gate_output = F.silu(fc3_output)
        fc1_output = block_scale_gemm(x_fp8, x_scale, w3_fp8, w3_scale)
        act_output = gate_output * fc1_output

        # gemm2
        act_fp8, act_scale = per_token_cast_to_fp8(act_output)
        output = block_scale_gemm(act_fp8, act_scale, w2_fp8, w2_scale)

        results[batch_idx] += final_scales[batch_idx, nth_expert, None] * output

    return results.view_as(x)
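
# Illustrative sketch only (not part of the original helpers): one way a test
# could build the `selected_experts` / `final_scales` tensors consumed by the
# reference MoE functions above. The plain softmax + top-k router and the
# helper name `_example_topk_routing` are assumptions made for this example;
# real tests may normalize or scale the routing weights differently.
def _example_topk_routing(router_logits: torch.Tensor,
                          top_k: int) -> Tuple[torch.Tensor, torch.Tensor]:
    # router_logits: (seq_len, num_experts)
    probs = torch.softmax(router_logits.float(), dim=-1)
    # final_scales: (seq_len, top_k) routing weights,
    # selected_experts: (seq_len, top_k) expert indices
    final_scales, selected_experts = probs.topk(top_k, dim=-1)
    return selected_experts, final_scales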
def create_mock_cuda_graph_runner(batch_size: int, use_mrope: bool = False):
    config = CUDAGraphRunnerConfig(
        use_cuda_graph=True,
        cuda_graph_padding_enabled=False,
        cuda_graph_batch_sizes=[batch_size],
        max_cuda_graph_batch_size=batch_size,
        batch_size=batch_size,
        max_beam_width=1,
        max_num_tokens=1,
        use_mrope=use_mrope,
        spec_config=None,
        cuda_graph_mem_pool=None,
        enable_attention_dp=False,
        original_max_draft_len=0,
        original_max_total_draft_tokens=0,
        is_draft_model=False,
        mapping=Mapping(),
        dist=None,
        kv_cache_manager_key=ResourceManagerType.KV_CACHE_MANAGER)
    return CUDAGraphRunner(config)
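
# Illustrative usage sketch (not part of the original helpers): comparing a
# weight-only-quantized output against a reference with the tolerance helper
# above. `ref` and `actual` are hypothetical tensors a test would produce, and
# the int8 weight dtype is an assumption for this example.
def _example_woq_compare(ref: torch.Tensor, actual: torch.Tensor) -> None:
    # Absolute tolerance scales with the reference magnitude and the weight
    # bit-width (8 bits here).
    atol = calc_woq_tolerence(ref, torch.int8)
    torch.testing.assert_close(actual, ref, atol=atol, rtol=0)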