Deepseek R1 FP8 Support on Blackwell (#6486)

Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Co-authored-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Co-authored-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Zongfei Jing 2025-08-01 10:26:28 +08:00 committed by GitHub
parent 8c165fd27a
commit 7bb0a78631
18 changed files with 1486 additions and 47 deletions

View File

@ -50,7 +50,10 @@ def add_llm_args(parser):
parser.add_argument('--moe_backend',
type=str,
default='CUTLASS',
choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
choices=[
'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
'DEEPGEMM', 'CUTEDSL'
])
parser.add_argument('--enable_attention_dp',
default=False,
action='store_true')
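For reference, a minimal sketch (not part of the diff; it assumes add_llm_args is importable from the quickstart example) of exercising the extended --moe_backend choices:
import argparse

parser = argparse.ArgumentParser()
add_llm_args(parser)  # the function patched above
args = parser.parse_args(["--moe_backend", "DEEPGEMM", "--enable_attention_dp"])
assert args.moe_backend == "DEEPGEMM"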

View File

@ -61,4 +61,5 @@ etcd3
blake3
llguidance==0.7.29
soundfile
deep_gemm @ git+https://github.com/zongfeijing/DeepGEMM.git@a9d538ef4dff0326fe521c6ca0bfde115703b56a
triton==3.3.1

View File

@ -12,7 +12,8 @@ from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
BaseWeightLoader
from tensorrt_llm._torch.models.modeling_utils import (
register_checkpoint_weight_loader, run_concurrently)
from tensorrt_llm._utils import local_mpi_rank, local_mpi_size
from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank,
local_mpi_size)
from tensorrt_llm.logger import logger
@ -38,6 +39,8 @@ class HfWeightLoader(BaseWeightLoader):
f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files."
)
self.prefetch_files(weight_files)
# Ensure that all local ranks have finished prefetching before loading weights
local_mpi_barrier()
return self._load_weights_in_parallel(
weight_files, self._load_safetensors_file,

View File

@ -38,12 +38,15 @@ from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig
from tensorrt_llm import logger
from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.llmapi.utils import enable_llm_debug
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.utils.fp8_utils import (
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
@ -1244,7 +1247,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size
local_num_heads = num_heads // weight_divisor
k_nope_weight_trans = k_nope_weight.transpose(2, 1)
k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous()
kv_b_proj = torch.concat([
k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim,
@ -1256,11 +1259,6 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
return kv_b_proj, k_nope_weight_trans
def check_weight_dtype(module_name: str, dtype):
weight_name = "weight"
w_dtype = weights[f"{module_name}.{weight_name}"].dtype
return w_dtype == dtype
def load_kv_b_proj_and_k_b_proj_trans_dequant(
module_name: str) -> torch.Tensor:
weight_name = "weight"
@ -1290,7 +1288,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size
local_num_heads = num_heads // weight_divisor
k_nope_weight_trans = k_nope_weight.transpose(2, 1)
k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous()
kv_b_proj = torch.concat([
k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim,
@ -1333,6 +1331,21 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
params_map = {'gate_up_proj': ['gate_proj', 'up_proj']}
all_named_modules = dict(self.named_modules())
if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
) and get_sm_version() == 100:
for name in list(weights.keys()):
# Use ".experts." to exclude shared_experts.
if name.endswith(
"weight_scale_inv") and ".experts." not in name:
weight_name = name.replace("weight_scale_inv", "weight")
logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
weights[weight_name] = weights[weight_name].cpu()
weights[name] = weights[name].cpu()
for name, module in tqdm(all_named_modules.items(),
desc="Loading weights"):
if len(module._parameters) > 0:
@ -1384,6 +1397,26 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
attn_module.v_b_proj_scale = nn.Parameter(
v_b_proj_scale, requires_grad=False)
if attn_module.k_b_proj_trans_dequant is not None:
attn_module.k_b_proj_trans_dequant.data.copy_(
weight_dequant(
k_b_proj_trans.view(
-1, k_b_proj_trans.shape[-1]).cuda(),
k_b_proj_trans_scale.view(
-1,
k_b_proj_trans_scale.shape[-1]).cuda(),
).view(
*attn_module.k_b_proj_trans_dequant.shape).
to(attn_module.k_b_proj_trans_dequant.dtype))
if attn_module.v_b_proj_dequant is not None:
attn_module.v_b_proj_dequant.data.copy_(
weight_dequant(
v_b_proj.view(-1,
v_b_proj.shape[-1]).cuda(),
v_b_proj_scale.view(
-1, v_b_proj_scale.shape[-1]).cuda(),
).view(*attn_module.v_b_proj_dequant.shape).to(
attn_module.v_b_proj_dequant.dtype))
elif names[-1] == "kv_a_proj_with_mqa":
fused_a = weights[
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
@ -1431,6 +1464,18 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
for n, p in module.named_parameters():
p.data.copy_(module_weights[n][:])
if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
) and get_sm_version() == 100 and hasattr(
module, "weight_scale"):
transformed_scale = transform_sf_into_required_layout(
module.weight_scale,
mn=module.weight.shape[0],
k=module.weight.shape[1],
recipe=(1, 128, 128),
is_sfa=False)
module.weight_scale = nn.Parameter(transformed_scale,
requires_grad=False)
for idx, layer in enumerate(
self.model.layers[:self.config.num_hidden_layers]):
if idx == self.config.num_hidden_layers - 1:

View File

@ -357,6 +357,7 @@ def fp8_block_scaling_bmm_out(
mat2_fp8: torch.Tensor,
mat2_scale: torch.Tensor,
out: torch.Tensor,
mat2_dequant: Optional[torch.Tensor] = None,
) -> torch.Tensor:
sm_version = get_sm_version()
if sm_version == 90 or sm_version == 89:
@ -365,30 +366,33 @@ def fp8_block_scaling_bmm_out(
torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,
mat1_scale, mat2_scale, out)
elif sm_version == 100:
low_latency = True
use_deep_seek_fp8 = True
tile_size = 8
epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
m_size = mat1.shape[0]
if m_size % tile_size != 0:
tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size
mat1 = torch.nn.functional.pad(
mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0)
output = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2))
out.copy_(output)
mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
mat1)
output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
mat1_fp8,
mat2_fp8,
tile_size=tile_size,
epilogue_tile_m=epilogue_tile_m,
use_deep_seek_fp8=use_deep_seek_fp8,
low_latency=low_latency,
dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1),
dq_sfs_b=mat2_scale,
out_dtype=out.dtype,
)
out.copy_(output[:, :m_size])
# low_latency = True
# use_deep_seek_fp8 = True
# tile_size = 8
# epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
# m_size = mat1.shape[0]
# if m_size % tile_size != 0:
# tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size
# mat1 = torch.nn.functional.pad(
# mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0)
# mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
# mat1)
# output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
# mat1_fp8,
# mat2_fp8,
# tile_size=tile_size,
# epilogue_tile_m=epilogue_tile_m,
# use_deep_seek_fp8=use_deep_seek_fp8,
# low_latency=low_latency,
# dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1),
# dq_sfs_b=mat2_scale,
# out_dtype=out.dtype,
# )
# out.copy_(output[:, :m_size])
else:
raise NotImplementedError(f"SM{sm_version} is not supported")
@ -676,6 +680,8 @@ class MLA(nn.Module):
requires_grad=False,
)
self.k_b_proj_trans_dequant = None
self.v_b_proj_dequant = None
if has_fp8_block_scales:
self.k_b_proj_trans_scale = nn.Parameter(
torch.empty(
@ -701,6 +707,23 @@ class MLA(nn.Module):
),
requires_grad=False,
)
if get_sm_version() == 100:
assert self.dtype == torch.bfloat16
self.k_b_proj_trans_dequant = nn.Parameter(
torch.empty(
(self.num_heads, self.kv_lora_rank,
self.qk_nope_head_dim),
dtype=self.dtype,
),
requires_grad=False,
)
self.v_b_proj_dequant = nn.Parameter(
torch.empty(
(self.num_heads, self.v_head_dim, self.kv_lora_rank),
dtype=self.dtype,
),
requires_grad=False,
)
else:
self.k_b_proj_trans_scale = None
self.v_b_proj_scale = None
@ -1197,8 +1220,13 @@ class MLA(nn.Module):
# [num_heads, num_tokens, self.kv_lora_rank]
q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
fp8_block_scaling_bmm_out(q_nope, self.k_b_proj_trans,
self.k_b_proj_trans_scale, q_nope_out)
fp8_block_scaling_bmm_out(
q_nope,
self.k_b_proj_trans,
self.k_b_proj_trans_scale,
q_nope_out,
self.k_b_proj_trans_dequant,
)
else:
raise NotImplementedError(
f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")
@ -1247,9 +1275,13 @@ class MLA(nn.Module):
self.v_b_proj.transpose(1, 2),
attn_output.transpose(0, 1))
elif self.v_b_proj.dtype == torch.float8_e4m3fn:
fp8_block_scaling_bmm_out(attn_out_latent, self.v_b_proj,
self.v_b_proj_scale,
attn_output.transpose(0, 1))
fp8_block_scaling_bmm_out(
attn_out_latent,
self.v_b_proj,
self.v_b_proj_scale,
attn_output.transpose(0, 1),
self.v_b_proj_dequant,
)
else:
raise NotImplementedError(
f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")

View File

@ -8,6 +8,7 @@ from tensorrt_llm.models.modeling_utils import QuantConfig
from ...model_config import ModelConfig
from .fused_moe_cute_dsl import CuteDslFusedMoE
from .fused_moe_cutlass import CutlassFusedMoE
from .fused_moe_deepgemm import DeepGemmFusedMoE
from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
from .fused_moe_vanilla import VanillaMoE
from .fused_moe_wide_ep import WideEPMoE
@ -31,6 +32,8 @@ def get_moe_cls(
return VanillaMoE
elif moe_backend.upper() == "CUTEDSL":
return CuteDslFusedMoE
elif moe_backend.upper() == "DEEPGEMM":
return DeepGemmFusedMoE
elif moe_backend.upper() == "TRTLLM":
if quant_config is not None and (
quant_config.quant_mode.has_fp8_block_scales()
@ -139,5 +142,19 @@ def create_moe(
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
)
elif moe_cls == DeepGemmFusedMoE:
return moe_cls(
routing_method=routing_method,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
)
else:
raise ValueError(f"Unsupported moe backend: {moe_cls}")

View File

@ -0,0 +1,483 @@
from typing import List, Optional, Union
import deep_gemm
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm._utils import nvtx_range
from ...distributed import allgather
from ...model_config import ModelConfig
from ...utils import Fp4QuantizedTensor
from .fused_moe_cutlass import CutlassFusedMoE
from .quantization import MoEWeightLoadingMode
from .routing import BaseMoeRoutingMethod
@triton.jit
def _masked_index_copy_group_quant_fp8(
input_ptr,
out_q_ptr,
out_s_ptr,
# mask indices
start_offsets_ptr,
row_indices_ptr,
# dimensions
row_size,
col_size,
dim_size,
group_size,
# output scale factor size
aligned_col,
aligned_dim,
# quantization parameters
eps,
fp8_max,
# block size
BLOCK: tl.constexpr,
NUM_STAGE: tl.constexpr,
):
group_block = tl.program_id(0)
token_block = tl.program_id(1)
token_block_num = tl.num_programs(1)
# calculate group and element offsets
num_tokens = tl.load(start_offsets_ptr + row_size)
elem_offsets = group_block * group_size * 4 + tl.arange(0, BLOCK)
output_s_offs = out_s_ptr + group_block * aligned_col
# process tokens
for token_index in tl.range(token_block,
num_tokens,
token_block_num,
num_stages=NUM_STAGE):
# load indices
row_idx = tl.load(row_indices_ptr + token_index)
start_offset = tl.load(start_offsets_ptr + row_idx)
idx = row_idx * col_size + token_index - start_offset
idx_s = row_idx * aligned_dim * aligned_col + token_index - start_offset
output_s_int32 = 0
for group_index in tl.range(4):
# load input data
dim_offset = elem_offsets + group_index * group_size
valid = dim_offset < dim_size
input_data = tl.load(input_ptr + token_index * dim_size +
dim_offset,
mask=valid,
other=0.0)
# quantization
_absmax = tl.maximum(tl.max(tl.abs(input_data)), eps)
output_s = _absmax / fp8_max
output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
output_q = tl.clamp(input_data / output_s, -fp8_max,
fp8_max).to(out_q_ptr.dtype.element_ty)
output_s = output_s.to(tl.int32, bitcast=True) >> 23
output_s_int32 += output_s << (group_index * 8)
# store quantized values and scaling factor
tl.store(out_q_ptr + idx * dim_size + dim_offset,
output_q,
mask=valid)
tl.store(output_s_offs + idx_s, output_s_int32)
def masked_index_copy_group_quant_fp8(
output: torch.Tensor,
input: torch.Tensor,
start_offsets: torch.Tensor,
row_indices: torch.Tensor,
group_size: int,
eps: float = 1e-10,
):
assert (
input.shape[-1] % group_size == 0
), "the last dimension of `input` cannot be divisible by `group_size`"
assert input.is_contiguous(), "`input` is not contiguous"
assert input.ndim == 2, "Input must be a 2D tensor"
assert output.ndim == 3, "Output must be a 3D tensor, [row, col, dim]"
assert start_offsets.shape[
0] == output.shape[0] + 1, "Start offsets must be (num_experts + 1)"
num_tokens = input.shape[0]
row_size = output.shape[0]
col_size = output.shape[1]
dim_size = output.shape[2]
# create padded output_s
alignment = 4
scale_dim = (dim_size + group_size - 1) // group_size
padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
padded_col_size = (col_size + alignment - 1) // alignment * alignment
output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size),
dtype=torch.int32,
device='cuda')
# get block/grid/stage/warp
num_groups = (dim_size + group_size - 1) // group_size
BLOCK = group_size
if num_tokens <= 1000 or col_size <= 256: # Small workload
TOKEN_BLOCK_NUM = 256
NUM_STAGES = 4
num_warps = 2
elif num_tokens <= 10000 or col_size <= 2048: # Medium workload
TOKEN_BLOCK_NUM = 1024
NUM_STAGES = 2
num_warps = 1
else: # Large workload
TOKEN_BLOCK_NUM = 2048
NUM_STAGES = 2
num_warps = 1
grid = (
(num_groups + 3) // 4,
TOKEN_BLOCK_NUM,
)
# FP8 quantization parameters
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
_masked_index_copy_group_quant_fp8[grid](
input,
output,
output_s,
start_offsets,
row_indices,
row_size,
col_size,
dim_size,
group_size,
padded_col_size,
padded_dim_size // 4,
eps,
fp8_max,
BLOCK=BLOCK,
NUM_STAGE=NUM_STAGES,
num_warps=num_warps,
)
output_s = output_s.transpose(1, 2)[:, :col_size, :]
return output_s
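A plain-PyTorch sketch (illustrative only, not used by the kernel) of the scale packing performed above: each group's power-of-two scale is bit-cast, its 8-bit exponent extracted with `>> 23`, and four exponents are packed into one int32:
import torch

scales = torch.tensor([1.0, 2.0, 0.5, 4.0])                              # power-of-two group scales
exponents = scales.view(torch.int32) >> 23                               # biased UE8M0 exponent per group
packed = (exponents << (torch.arange(4, dtype=torch.int32) * 8)).sum()   # four exponents in one int32 word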
@triton.jit
def masked_index_gather_kernel(output_ptr, input_ptr, start_offsets_ptr,
row_indices_ptr, row_size, col_size, dim_size,
BLOCK_SIZE: tl.constexpr):
# get program id and block offset
pid = tl.program_id(0)
num_tokens = tl.load(start_offsets_ptr + row_size)
token_idx = pid
valid_token = token_idx < num_tokens
if not valid_token:
return
row_idx = tl.load(row_indices_ptr + token_idx)
start_offset = tl.load(start_offsets_ptr + row_idx)
col_idx = token_idx - start_offset
# Process elements in blocks
for hidden_start in tl.range(0, dim_size, BLOCK_SIZE):
hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
valid_hidden = hidden_indices < dim_size
input_offset = row_idx * col_size * dim_size + col_idx * dim_size + hidden_indices
input_val = tl.load(input_ptr + input_offset,
mask=valid_hidden,
other=0.0)
output_offset = pid * dim_size + hidden_indices
tl.store(output_ptr + output_offset, input_val, mask=valid_hidden)
@torch.no_grad()
def triton_masked_index_gather(output, input, start_offsets, row_indices):
assert output.ndim == 2, "Output must be a 2D tensor"
assert input.ndim == 3, "Input must be a 3D tensor, [row, col, dim]"
assert start_offsets.shape[
0] == input.shape[0] + 1, "Start offsets must be (num_experts + 1)"
row_size = input.shape[0]
col_size = input.shape[1]
dim_size = input.shape[2]
num_tokens = output.shape[0]
grid = (num_tokens, )
# launch kernel
masked_index_gather_kernel[grid](output,
input,
start_offsets,
row_indices,
row_size,
col_size,
dim_size,
BLOCK_SIZE=1024)
return
@nvtx_range("[DG] act")
@torch.compile(dynamic=True)
def swiglu_fused_moe(x):
x, gate = x.chunk(2, dim=-1)
return F.silu(gate) * x
@nvtx_range("[DG] indexing")
@torch.compile(dynamic=True)
def indexing(x, mask):
return x[mask > 0, :].contiguous()
@nvtx_range("[DG] preprocess_after_permute")
def preprocess_after_permute(expert_first_token_offset_tensor,
permuted_data_tensor):
# get tokens per expert
masked_m = expert_first_token_offset_tensor[
1:] - expert_first_token_offset_tensor[:-1]
token_to_expert_map = torch.searchsorted(
expert_first_token_offset_tensor[1:],
torch.arange(permuted_data_tensor.shape[0], device='cuda'),
right=True)
return masked_m.to(torch.int32), token_to_expert_map
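A small sketch (illustrative values) of what the searchsorted call above computes: it maps each permuted token row back to the expert that owns it, using the cumulative per-expert offsets:
import torch

offsets = torch.tensor([0, 3, 3, 7])                        # expert_first_token_offset for 3 experts
token_to_expert = torch.searchsorted(offsets[1:], torch.arange(7), right=True)
# -> tensor([0, 0, 0, 2, 2, 2, 2]); expert 1 received no tokens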
@nvtx_range("[DG]")
def deepgemm_fp8_group_blockwise_gemm(
a: torch.Tensor,
b: torch.Tensor,
sfa: torch.Tensor,
sfb: torch.Tensor,
masked_m: torch.Tensor,
expected_m: int,
) -> torch.Tensor:
d = torch.empty((a.shape[0], a.shape[1], b.shape[1]),
device=b.device,
dtype=torch.bfloat16)
# NOTES: shape must be `[G, M, K] @ [G, N, K].mT`
assert a.stride(-1) == 1
assert b.stride(-1) == 1
assert masked_m.is_contiguous()
num_groups, m, k = a.shape
num_groups_, n, k_ = b.shape
num_groups__, m_, n_ = d.shape
num_groups___ = masked_m.numel()
# Type and shape checks
assert num_groups == num_groups_ == num_groups__ == num_groups___
assert m == m_ and n == n_ and k == k_
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
assert a.dtype == torch.float8_e4m3fn
assert b.dtype == torch.float8_e4m3fn
assert d.dtype == torch.bfloat16
assert masked_m.dtype == torch.int32
# D must be N-major
assert d.stride(-1) == 1
# Transform SFA and SFB into compute-required layout
deep_gemm.fp8_m_grouped_gemm_nt_masked((a, sfa), (b, sfb),
d,
masked_m,
expected_m,
disable_ue8m0_cast=True)
return d
class DeepGemmFusedMoE(CutlassFusedMoE):
"""
Python Flow of Fused Mixture of Experts (MoE) Layer.
Args:
num_experts (int): Number of experts in the MoE layer.
top_k (int): Number of top experts to select for each input token.
hidden_size (int): Size of the hidden state.
intermediate_size (int): Size of the intermediate state.
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
This backend is composed of multiple custom ops:
1. moe_permute_op: permute the input tensor and the expert-selection tensors.
2. deepgemm_fp8_group_blockwise_gemm: the masked, FP8 block-wise grouped GEMM executed by DeepGEMM.
3. moe_finalize_scale_op: finalize the scale of the output tensor.
"""
def __init__(
self,
*,
routing_method: BaseMoeRoutingMethod,
num_experts: int,
hidden_size: int,
intermediate_size: int,
dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
model_config: ModelConfig = ModelConfig(),
aux_stream: Optional[torch.cuda.Stream] = None,
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
VANILLA,
apply_router_weight_on_input: bool = False,
layer_idx: Optional[int] = None,
):
super().__init__(
routing_method=routing_method,
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
dtype=dtype,
reduce_results=reduce_results,
model_config=model_config,
aux_stream=aux_stream,
weight_loading_mode=weight_loading_mode,
apply_router_weight_on_input=apply_router_weight_on_input,
layer_idx=layer_idx,
)
@nvtx_range("[DG] forward")
def forward_chunk(
self,
x: Union[torch.Tensor, Fp4QuantizedTensor],
router_logits: torch.Tensor,
output_dtype: Optional[torch.dtype] = None,
all_rank_num_tokens: Optional[List[int]] = None,
use_dp_padding: Optional[bool] = None,
) -> torch.Tensor:
if isinstance(x, Fp4QuantizedTensor):
assert output_dtype is not None
output_dtype = output_dtype
else:
output_dtype = x.dtype
# apply routing
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
assert token_selected_experts.shape[
1] == self.routing_method.experts_per_token
assert token_selected_experts.shape == token_final_scales.shape
assert token_selected_experts.shape[0] == router_logits.shape[0]
assert token_final_scales.dtype == torch.float32
assert token_selected_experts.dtype == torch.int32
if self.apply_router_weight_on_input:
assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
x = x * token_final_scales.to(x.dtype)
# TODO: remove this once we have correct fusedmoe kernel ready
token_final_scales = None
# quantize inputs
use_deepseek_fp8_block_scale = False
x_sf = None
if self.has_any_quant:
if self.has_deepseek_fp8_block_scales:
use_deepseek_fp8_block_scale = True
else:
raise ValueError(
f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}"
)
use_allgather = self.use_dp and self.parallel_size > 1
if use_allgather:
x, x_sf, token_selected_experts, token_final_scales = allgather(
[x, x_sf, token_selected_experts, token_final_scales],
self.mapping,
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
(
permuted_row_to_unpermuted_row_tensor,
permuted_token_selected_experts_tensor,
permuted_data_tensor,
expert_first_token_offset_tensor,
permuted_token_final_scales_tensor,
unpermuted_row_to_permuted_row_tensor,
) = torch.ops.trtllm.moe_permute_op(
x,
token_selected_experts,
token_final_scales,
None, # w3_w1_weight.view(weight_dtype),
None, # w2_weight.view(weight_dtype),
None, # quant_scales,
input_sf=x_sf,
num_experts_on_rank=self.expert_size_per_partition,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
cluster_size=self.cluster_size,
cluster_rank=self.cluster_rank,
min_latency_mode=False,
use_fp8_block_scaling=use_deepseek_fp8_block_scale,
)
if permuted_data_tensor.numel() == 0:
return torch.zeros_like(x)
masked_m, token_to_expert_map = preprocess_after_permute(
expert_first_token_offset_tensor, permuted_data_tensor)
m_max = (x.shape[0] + 127) // 128 * 128
expected_m = (token_selected_experts.numel() +
self.expert_size_per_partition -
1) // self.expert_size_per_partition
act_input_fp8 = torch.empty(
(self.expert_size_per_partition, m_max, self.hidden_size),
dtype=torch.float8_e4m3fn,
device='cuda')
act_input_sf = masked_index_copy_group_quant_fp8(
act_input_fp8,
permuted_data_tensor,
expert_first_token_offset_tensor,
token_to_expert_map,
group_size=128)
h1 = deepgemm_fp8_group_blockwise_gemm(
a=act_input_fp8,
b=self.w3_w1_weight,
sfa=act_input_sf,
sfb=self.quant_scales[0],
masked_m=masked_m,
expected_m=expected_m,
)
act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True)
h3 = deepgemm_fp8_group_blockwise_gemm(
a=act_input_fp8,
b=self.w2_weight,
sfa=act_input_sf,
sfb=self.quant_scales[1],
masked_m=masked_m,
expected_m=expected_m,
)
triton_masked_index_gather(permuted_data_tensor, h3,
expert_first_token_offset_tensor,
token_to_expert_map)
final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op(
permuted_data_tensor,
None, # biases
token_final_scales,
unpermuted_row_to_permuted_row_tensor,
permuted_row_to_unpermuted_row_tensor,
token_selected_experts,
expert_first_token_offset_tensor,
False, # enable_alltoall
x.shape[0], # num_rows
x.shape[1], # hidden_size
self.routing_method.top_k,
self.expert_size_per_partition, # num_experts_per_node
self.tp_size,
self.tp_rank,
self.ep_size,
self.ep_rank,
)
return final_hidden_states
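A short sketch (illustrative numbers, not part of the diff) of the buffer-sizing arithmetic used in forward_chunk above: the per-expert activation buffer is padded to a multiple of 128 rows, and expected_m is the average token count per local expert:
num_tokens, top_k, experts_per_rank = 300, 8, 32
m_max = (num_tokens + 127) // 128 * 128                                        # 384 rows per expert buffer
expected_m = (num_tokens * top_k + experts_per_rank - 1) // experts_per_rank   # 75 tokens per expert on average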

View File

@ -4,10 +4,13 @@ from typing import Dict, List, NamedTuple, Union
import torch
from torch import nn
from tensorrt_llm import logger
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.quantization.utils.fp4_utils import (
float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices,
get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices)
from tensorrt_llm.quantization.utils.fp8_utils import (
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
from ..linear import TensorParallelMode, load_weight_shard
from .interface import MoEWeightLoadingMode
@ -463,6 +466,47 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
self.setup_quant_scales(module)
def load_weights(self, module: torch.nn.Module, weights: List[Dict],
weight_loading_mode: MoEWeightLoadingMode):
if get_sm_version() == 100:
expert_ids = set(module.initial_local_expert_ids)
if self.need_load_shared_weights(module):
expert_ids.update(
module.layer_load_balancer.get_load_expert_ids())
for name in list(weights.keys()):
if name.endswith("weight_scale_inv"):
if int(name.split(".")[0]) not in expert_ids:
continue
weight_name = name.replace("weight_scale_inv", "weight")
logger.debug(f"Resmoothing {weight_name}")
weight = weights[weight_name][:]
scale = weights[name][:]
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
weight, scale)
super().load_weights(module, weights, weight_loading_mode)
if get_sm_version() == 100:
transformed_w3_w1_scale = transform_sf_into_required_layout(
module.quant_scales[0],
mn=module.w3_w1_weight.shape[1],
k=module.w3_w1_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w3_w1_weight_scaling_factor = nn.Parameter(
transformed_w3_w1_scale, requires_grad=False)
transformed_w2_scale = transform_sf_into_required_layout(
module.quant_scales[1],
mn=module.w2_weight.shape[1],
k=module.w2_weight.shape[2],
recipe=(1, 128, 128),
num_groups=module.w3_w1_weight.shape[0],
is_sfa=False)
module.w2_weight_scaling_factor = nn.Parameter(transformed_w2_scale,
requires_grad=False)
self.setup_quant_scales(module)
def setup_quant_scales(self, module: torch.nn.Module):
module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales(
fc_weight_scales=module.w3_w1_weight_scaling_factor,

View File

@ -12,6 +12,7 @@ from torch import nn
from torch.nn.parameter import Parameter
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm._torch.peft.lora.layer import LoraLayer
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
AllReduceStrategy)
@ -20,6 +21,7 @@ from tensorrt_llm.quantization.functional import \
preprocess_weights_for_mixed_gemm
from tensorrt_llm.quantization.mode import QuantAlgo
from ..._utils import get_sm_version
from ...models.modeling_utils import QuantConfig
from ..utils import Fp4QuantizedTensor
@ -570,10 +572,22 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
input = input.to(torch.bfloat16) * module.input_scale
assert input.dtype == torch.bfloat16
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(input)
if get_sm_version() == 100:
import deep_gemm
a, a_sf = fp8_utils.per_token_quant_and_transform(input)
output = torch.empty((input.shape[0], module.weight.shape[0]),
device=input.device,
dtype=torch.bfloat16)
deep_gemm.fp8_gemm_nt((a, a_sf),
(module.weight, module.weight_scale),
output,
disable_ue8m0_cast=True)
else:
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
input)
output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, module.weight, act_input_sf, module.weight_scale)
output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, module.weight, act_input_sf, module.weight_scale)
if bias is not None:
output = output + bias
return output
@ -603,7 +617,6 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
copy_weight(module.weight, fused_weight)
scale_name = self._get_scale_name(weights)
q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
@ -614,6 +627,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
module.tp_rank, module.tp_mode)
fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
copy_weight(module.weight, fused_weight)
copy_weight(module.weight_scale, fused_fp8_block_scale)
def load_weights_fused_gate_up_linear(self, module: Linear,
@ -621,7 +635,6 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
copy_weight(module.weight, fused_weight)
scale_name = self._get_scale_name(weights)
left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
@ -629,6 +642,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
copy_weight(module.weight, fused_weight)
copy_weight(module.weight_scale, fused_scale)

View File

@ -144,6 +144,8 @@ class KvCacheCreator:
end_id=-1)
requests.append(request)
remaining_tokens -= input_seq_len
if self._mapping.enable_attention_dp:
requests = requests * self._mapping.tp_size
return requests
def _get_token_num_for_estimation(self) -> int:

View File

@ -470,6 +470,10 @@ def mpi_comm():
local_comm = mpi_comm().Split_type(split_type=OMPI_COMM_TYPE_HOST)
def local_mpi_comm():
return local_comm
def mpi_rank():
return mpi_comm().Get_rank() if ENABLE_MULTI_DEVICE else 0
@ -508,6 +512,11 @@ def mpi_barrier():
mpi_comm().Barrier()
def local_mpi_barrier():
if ENABLE_MULTI_DEVICE:
local_comm.Barrier()
def mpi_broadcast(obj, root=0):
return mpi_comm().bcast(obj, root) if is_multi_device_enable() else obj

View File

@ -167,7 +167,7 @@ class MoeConfig(StrictBaseModel):
"""
Configuration for MoE.
"""
backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM",
backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM",
"VANILLA"] = Field(default='CUTLASS',
description="MoE backend to use.")

View File

@ -1,3 +1,3 @@
from . import fp4_utils
from . import fp4_utils, fp8_utils
__all__ = ['fp4_utils']
__all__ = ['fp4_utils', 'fp8_utils']

View File

@ -0,0 +1,530 @@
from typing import Optional, Tuple
import torch
import triton
import triton.language as tl
from tensorrt_llm._utils import nvtx_range
def ceil_div(x: int, y: int) -> int:
"""
Perform ceiling division of two integers.
Args:
x: the dividend.
y: the divisor.
Returns:
The result of the ceiling division.
"""
return (x + y - 1) // y
def align(x: int, y: int) -> int:
return ceil_div(x, y) * y
def ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
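For intuition, a quick sketch of ceil_to_ue8m0 as defined above: every scale is rounded up to the next power of two so it can be stored as a bare UE8M0 exponent:
s = torch.tensor([0.3, 1.0, 5.0])
ceil_to_ue8m0(s)   # tensor([0.5000, 1.0000, 8.0000])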
@nvtx_range("[DG] quantization")
@torch.compile(dynamic=True)
def per_token_cast_to_fp8_e8m0(
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if x.dim() == 2:
assert x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(
m, n), sf
else:
assert x.size(2) % 128 == 0
g, m, n = x.shape
x_view = x.view(g, m, -1, 128)
x_amax = x_view.abs().float().amax(dim=3).view(g, m, -1).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
return (x_view * (1.0 / sf.unsqueeze(3))).to(torch.float8_e4m3fn).view(
g, m, n), sf
def per_block_cast_to_fp8_e8m0(
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if x.dim() == 2:
m, n = x.shape
x_padded = torch.zeros((align(m, 128), align(n, 128)),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
x_view.size(0), x_view.size(2))
else:
g, m, n = x.shape
x_padded = torch.zeros((g, align(m, 128), align(n, 128)),
dtype=x.dtype,
device=x.device)
x_padded[:, :m, :n] = x
x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view(
x_view.size(0), x_view.size(1), x_view.size(3))
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
weight = weight.cuda()
sf = sf.cuda()
if weight.dim() == 2:
x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
128, dim=1)[:weight.shape[0], :weight.shape[1]]
else:
x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
return per_block_cast_to_fp8_e8m0(x)
def get_m_alignment_for_contiguous_layout():
return 128
def get_tma_aligned_size(x: int, element_size: int) -> int:
tma_alignment_bytes = 16
assert tma_alignment_bytes % element_size == 0
alignment = tma_alignment_bytes // element_size
return align(x, alignment)
def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor:
# NOTES: for the extreme performance, you may rewrite/fuse this function in CUDA
assert x.dtype == torch.float and x.dim() in (2, 3)
# First, convert into UE8M0 `uint8_t`
ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8)
# Second, make padded packed tensors
mn, k = x.shape[-2], x.shape[-1]
remove_dim = False
if x.dim() == 2:
x, remove_dim = x.unsqueeze(0), True
b = x.shape[0]
aligned_mn = get_tma_aligned_size(mn, 4)
aligned_k = align(k, 4)
padded = torch.zeros((b, aligned_mn, aligned_k),
device=x.device,
dtype=torch.uint8)
padded[:, :mn, :k] = ue8m0_tensor
padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn,
aligned_k // 4)
# Finally, transpose
transposed = torch.transpose(
torch.empty((b, aligned_k // 4, aligned_mn),
device=x.device,
dtype=torch.int), 1, 2)
transposed[:, :, :] = padded
aligned_x = transposed[:, :mn, :]
return aligned_x.squeeze(0) if remove_dim else aligned_x
def check_sf_layout(sf: torch.Tensor,
mn: int,
k: int,
gran: Tuple[int, int],
num_groups: Optional[int],
tma_stride_check: bool = False,
type_check: Optional[torch.dtype] = None) -> torch.Tensor:
# Type check
if type_check is not None:
assert sf.dtype == type_check
# Always do shape checks
assert sf.dtype in (torch.float, torch.int)
assert sf.dim() == int(num_groups is not None) + 2
if num_groups is not None:
assert sf.size(-3) == num_groups
assert sf.size(-2) == ceil_div(mn, gran[0])
assert sf.size(-1) == ceil_div(
k, gran[1] * (1 if sf.dtype == torch.float else 4))
# TMA stride checks: TMA aligned and MN-major
if tma_stride_check:
if num_groups is not None:
assert sf.stride(-3) == sf.stride(-1) * sf.size(-1)
assert sf.stride(-2) == 1
assert sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())
return sf
@nvtx_range("[DG] transform_sf_into_required_layout")
def transform_sf_into_required_layout(sf: torch.Tensor,
mn: int,
k: int,
recipe: Tuple[int, int, int],
num_groups: Optional[int] = None,
is_sfa: bool = False):
gran = (recipe[0 if is_sfa else 1], recipe[2])
should_skip_transform = ((sf.dtype == torch.int and gran == (1, 128))
or (sf.dtype == torch.int and gran == (128, 128)))
if not should_skip_transform:
# Pre-transform checks
check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)
# (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
if sf.dtype == torch.float and gran == (1, 128):
sf = get_col_major_tma_aligned_packed_tensor(sf)
return check_sf_layout(sf,
mn=mn,
k=k,
gran=(1, 128),
num_groups=num_groups,
tma_stride_check=True,
type_check=torch.int)
# (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
if sf.dtype == torch.float and gran == (128, 128):
sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
sf = get_col_major_tma_aligned_packed_tensor(sf)
return check_sf_layout(sf,
mn=mn,
k=k,
gran=(1, 128),
num_groups=num_groups,
tma_stride_check=True,
type_check=torch.int)
if should_skip_transform:
# TODO: add transpose kernel if SF layout is not satisfied
return check_sf_layout(sf,
mn=mn,
k=k,
gran=(1, 128),
num_groups=num_groups,
tma_stride_check=True,
type_check=torch.int)
assert False, f'Unknown cases: {sf.dtype=}, {gran=}'
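An illustrative call (shapes chosen for the example only) of the function above, packing FP32 (128, 128)-granularity weight scales into the int32, MN-major, TMA-aligned layout expected on SM100:
sf = torch.rand(512 // 128, 7168 // 128)            # one FP32 scale per 128x128 tile of a (512, 7168) weight
packed = transform_sf_into_required_layout(sf, mn=512, k=7168,
                                            recipe=(1, 128, 128), is_sfa=False)
# packed: torch.int32, shape (512, 14), four UE8M0 exponents per element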
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
@triton.jit
def _silu_and_mul_post_quant_kernel(
input_ptr,
stride_input_0,
stride_input_1,
stride_input_2,
output_ptr,
stride_output_0,
stride_output_1,
stride_output_2,
output_scale_ptr,
stride_output_scale_0,
stride_output_scale_1,
stride_output_scale_2,
masked_m_ptr,
size_k,
fp8_max,
fp8_min,
BLOCK: tl.constexpr,
NUM_STAGE: tl.constexpr,
SCALE_UE8M0: tl.constexpr,
):
expert_id = tl.program_id(2)
token_id = tl.program_id(1)
hidden_dim_block_index = tl.program_id(0)
block_num_per_expert = tl.num_programs(1)
token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4)
input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
output_scale_offs = (output_scale_ptr + expert_id * stride_output_scale_0 +
hidden_dim_block_index * stride_output_scale_1)
for token_index in tl.range(token_id,
token_num_cur_expert,
block_num_per_expert,
num_stages=NUM_STAGE):
output_s_int32 = 0
for pack_index in tl.range(4):
local_mask = offs_in_d + pack_index * 128
up = tl.load(
input_ptr_offs + token_index * stride_input_1 +
pack_index * 128,
mask=local_mask < size_k,
other=0.0,
)
gate = tl.load(
input_ptr_offs + token_index * stride_input_1 + size_k +
pack_index * 128,
mask=local_mask < size_k,
other=0.0,
).to(tl.float32)
gate = gate / (1 + tl.exp(-gate))
gate = gate.to(input_ptr.dtype.element_ty)
gate_up = up * gate
_absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
output_s = _absmax / fp8_max
if SCALE_UE8M0:
output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
output_q = tl.clamp(gate_up / output_s, fp8_min,
fp8_max).to(output_ptr.dtype.element_ty)
output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) <<
(8 * pack_index))
tl.store(
output_ptr_offs + token_index * stride_output_1 +
pack_index * 128,
output_q,
mask=local_mask < size_k,
)
tl.store(
output_scale_offs + token_index * stride_output_scale_2,
output_s_int32,
)
def silu_and_mul_masked_post_quant_fwd(
input: torch.Tensor,
quant_group_size: int,
masked_m: torch.Tensor,
scale_ue8m0: bool = False,
):
"""
input shape [g, m, k] (gate and up halves concatenated along the last dim)
output shape [g, m, k // 2], dtype fp8
output_scale shape [g, m, ceil(k // 2 / quant_group_size / 4)], dtype int32 (four packed UE8M0 exponents per element)
quant_group_size int
masked_m shape [g], number of valid tokens per expert
"""
assert input.is_contiguous()
assert len(input.shape) == 3
assert input.shape[0] == masked_m.shape[0]
assert input.shape[-1] % 2 == 0
# FP8 quantization parameters
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = finfo.min
g, m, k = input.shape
k = k // 2
# Create output
output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda")
# Create output scale
alignment = 4
scale_k = ceil_div(k, quant_group_size)
m_padded = align(m, alignment)
scale_k_padded = align(scale_k, alignment)
output_scale = torch.zeros((g, scale_k_padded // 4, m_padded),
dtype=torch.int32,
device='cuda')
# Get block/grid/stage/warp
expert_num = len(masked_m)
if expert_num < 4:
BLOCK_NUM_PER_EXPERT = 64
else:
BLOCK_NUM_PER_EXPERT = 128
BLOCK = quant_group_size * 4
num_warps = 1
NUM_STAGES = 6
hidden_dim_split_block_num = triton.cdiv(k, BLOCK)
grid = (
hidden_dim_split_block_num,
BLOCK_NUM_PER_EXPERT,
expert_num,
)
_silu_and_mul_post_quant_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
output_scale,
*output_scale.stride(),
masked_m,
k,
fp8_max,
fp8_min,
BLOCK=BLOCK,
NUM_STAGE=NUM_STAGES,
num_warps=num_warps,
SCALE_UE8M0=scale_ue8m0,
)
output_scale = output_scale.transpose(1, 2)[:, :m, :]
check_sf_layout(
output_scale,
m,
k,
(1, 128),
g,
tma_stride_check=True,
)
return output, output_scale
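A usage sketch of the wrapper above (requires a GPU with Triton; shapes illustrative): the last dimension holds the concatenated gate/up halves and is halved after the fused SiLU-mul and quantization:
h = torch.randn(4, 128, 512, dtype=torch.bfloat16, device="cuda")        # [g, m, 2 * 256]
masked = torch.tensor([128, 64, 0, 32], dtype=torch.int32, device="cuda")
out, out_sf = silu_and_mul_masked_post_quant_fwd(h, quant_group_size=128,
                                                 masked_m=masked, scale_ue8m0=True)
# out: [4, 128, 256] float8_e4m3fn; out_sf: [4, 128, 1] int32 (packed UE8M0 scales)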
@triton.jit
def _per_token_quant_and_transform_kernel(
input_ptr,
stride_input_0,
stride_input_1,
output_ptr,
stride_output_0,
stride_output_1,
output_scale_ptr,
stride_output_scale_0,
stride_output_scale_1,
token_num_cur_expert,
size_k,
fp8_max,
fp8_min,
BLOCK: tl.constexpr,
NUM_STAGE: tl.constexpr,
SCALE_UE8M0: tl.constexpr,
):
tl.program_id(2)
token_id = tl.program_id(1)
hidden_dim_block_index = tl.program_id(0)
block_num_per_expert = tl.num_programs(1)
stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4)
input_ptr_offs = input_ptr + offs_in_d
output_ptr_offs = output_ptr + offs_in_d
output_scale_offs = (output_scale_ptr +
hidden_dim_block_index * stride_output_scale_0)
for token_index in tl.range(token_id,
token_num_cur_expert,
block_num_per_expert,
num_stages=NUM_STAGE):
output_s_int32 = 0
for pack_index in tl.range(4):
local_mask = offs_in_d + pack_index * 128
act = tl.load(
input_ptr_offs + token_index * stride_input_0 +
pack_index * 128,
mask=local_mask < size_k,
other=0.0,
).to(tl.float32)
_absmax = tl.maximum(tl.max(tl.abs(act)), 1e-10)
output_s = _absmax / fp8_max
if SCALE_UE8M0:
output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
output_q = tl.clamp(act / output_s, fp8_min,
fp8_max).to(output_ptr.dtype.element_ty)
output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) <<
(8 * pack_index))
tl.store(
output_ptr_offs + token_index * stride_output_0 +
pack_index * 128,
output_q,
mask=local_mask < size_k,
)
tl.store(
output_scale_offs + token_index * stride_output_scale_1,
output_s_int32,
)
def per_token_quant_and_transform(
input: torch.Tensor,
quant_group_size: int = 128,
scale_ue8m0: bool = True,
):
"""
input shape [m, k]
output shape [m, k], dtype fp8
output_scale shape [m, ceil(k / quant_group_size / 4)], dtype int32 (four packed UE8M0 exponents per element)
quant_group_size int
"""
assert input.is_contiguous()
assert len(input.shape) == 2
assert input.shape[-1] % 2 == 0
# FP8 quantization parameters
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = -fp8_max
m, k = input.shape
# Create output
output = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda")
# Create output scale
alignment = 4
scale_k = ceil_div(k, quant_group_size)
m_padded = align(m, alignment)
scale_k_padded = align(scale_k, alignment)
output_scale = torch.zeros((scale_k_padded // 4, m_padded),
dtype=torch.int32,
device='cuda')
# Get block/grid/stage/warp
BLOCK_NUM_PER_EXPERT = 64
BLOCK = quant_group_size * 4
num_warps = 1
NUM_STAGES = 6
hidden_dim_split_block_num = triton.cdiv(k, BLOCK)
grid = (
hidden_dim_split_block_num,
BLOCK_NUM_PER_EXPERT,
1,
)
_per_token_quant_and_transform_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
output_scale,
*output_scale.stride(),
m,
k,
fp8_max,
fp8_min,
BLOCK=BLOCK,
NUM_STAGE=NUM_STAGES,
num_warps=num_warps,
SCALE_UE8M0=scale_ue8m0,
)
output_scale = output_scale.transpose(0, 1)[:m, :]
check_sf_layout(
output_scale,
m,
k,
(1, 128),
num_groups=None,
tma_stride_check=True,
)
return output, output_scale

View File

@ -8,6 +8,14 @@ def ceil_div(x: int, y: int) -> int:
return (x + y - 1) // y
def align(x: int, y: int) -> int:
return ceil_div(x, y) * y
def ceil_to_ue8m0(x: torch.Tensor):
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
@ -33,6 +41,33 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x_view.size(2))
def per_token_cast_to_fp8_e8m0(
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(
m, n), sf
def per_block_cast_to_fp8_e8m0(
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros((align(m, 128), align(n, 128)),
dtype=x.dtype,
device=x.device)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
sf = ceil_to_ue8m0(x_amax / 448.0)
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
x_view.size(0), x_view.size(2))
def calc_diff(x, y):
x, y = x.double(), y.double()
denominator = (x * x + y * y).sum()

View File

@ -9,7 +9,8 @@ import cloudpickle
import pytest
import torch
import torch.nn as nn
from _torch.helpers import per_block_cast_to_fp8
from _torch.helpers import (per_block_cast_to_fp8, per_block_cast_to_fp8_e8m0,
per_token_cast_to_fp8_e8m0)
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from utils.util import (skip_neither_ada_nor_hopper_unittest,
@ -25,6 +26,8 @@ from tensorrt_llm._torch.modules.fused_moe import (BaseMoeRoutingMethod,
VanillaMoE, WideEPMoE)
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \
CuteDslFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \
DeepGemmFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \
AlltoallMethodType
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
@ -379,6 +382,183 @@ def set_tensor_value_4(x, num_row, num_cols):
x.copy_(repeated)
@skip_pre_blackwell
@pytest.mark.parametrize(
"dtype, num_experts, seq_len, hidden_size, RoutingMethodCls",
product(
[torch.bfloat16],
[72],
[128, 256, 384, 512, 1024, 2048, 4096, 8192],
[2560],
[DefaultMoeRoutingMethod],
),
)
def test_fused_moe_fp8_blockwise_deepgemm(dtype,
num_experts,
seq_len,
hidden_size,
RoutingMethodCls,
mapping=None):
SEQ_LEN = seq_len
HIDDEN_SIZE = hidden_size
INTERMEDIATE_SIZE = 256
NUM_EXPERTS = num_experts
TOP_K = 2
routing_method = RoutingMethodCls(top_k=TOP_K)
mapping = mapping or Mapping()
mapping.rank = mpi_rank()
torch.cuda.set_device(mapping.rank)
torch.manual_seed(0)
torch.cuda.manual_seed(0)
x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
# Note: initialize x and the weights with special values, otherwise the test can pass spuriously (false positive).
set_tensor_value_2(x, SEQ_LEN, HIDDEN_SIZE)
x = x.cuda()
router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda()
weights = {}
w3_w1_weight_scales = []
w2_weight_scales = []
for expert_id in range(NUM_EXPERTS):
w1_weight = torch.randn(
(INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() / HIDDEN_SIZE
w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE),
dtype=dtype).cuda()
w3_weight = torch.randn(
(INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() / HIDDEN_SIZE
set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE)
set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE)
set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE)
w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8_e8m0(w1_weight)
w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda()
w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8_e8m0(w2_weight)
w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda()
w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8_e8m0(w3_weight)
w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda()
weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale
weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale
weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale
w3_w1_weight_scales.append(
torch.cat([w3_weight_scale, w1_weight_scale], dim=0))
w2_weight_scales.append(w2_weight_scale)
w3_w1_weight_scales = torch.stack(w3_w1_weight_scales, dim=0).cuda()
w2_weight_scales = torch.stack(w2_weight_scales, dim=0).cuda()
quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES)
fused_moe = DeepGemmFusedMoE(
num_experts=NUM_EXPERTS,
routing_method=routing_method,
hidden_size=HIDDEN_SIZE,
intermediate_size=INTERMEDIATE_SIZE,
dtype=dtype,
reduce_results=True,
model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
)
fused_moe.cuda()
fused_moe.load_weights([weights])
def swiglu_fused_moe(x):
x, gate = x.chunk(2, dim=-1)
return torch.nn.functional.silu(gate) * x
def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor,
b_sf: torch.Tensor,
offset_array: torch.Tensor) -> torch.Tensor:
d = torch.empty((a.shape[0], b.shape[1]),
device=b.device,
dtype=torch.bfloat16)
m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32)
for idx in range(offset_array.numel() - 1):
m_indices[offset_array[idx]:offset_array[idx + 1]] = idx
num_groups, n, k_ = b.shape
for g in range(num_groups):
aa = a[offset_array[g]:offset_array[g + 1], :].to(torch.bfloat16)
aa_sf = a_sf[offset_array[g]:offset_array[g + 1], :]
aa_dq = aa * aa_sf.repeat_interleave(
128, dim=1)[:aa.shape[0], :aa.shape[1]]
bb = b[g, :, :].to(torch.bfloat16)
bb_sf = b_sf[g, :, :]
bb_dq = bb * bb_sf.repeat_interleave(128, dim=0).repeat_interleave(
128, dim=1)[:bb.shape[0], :bb.shape[1]]
d[offset_array[g]:offset_array[g + 1], :] = (aa_dq @ bb_dq.t())
return d
token_selected_experts, token_final_scales = routing_method.apply(
router_logits)
t_idx = 0
permuted_data_tensor = torch.empty((x.shape[0] * TOP_K, x.shape[1]),
device=x.device,
dtype=torch.bfloat16)
expert_first_token_offset_tensor = torch.zeros(NUM_EXPERTS + 1,
dtype=torch.int32)
unpermute_map = []
scales = []
for e_idx in range(NUM_EXPERTS):
for idx, token in enumerate(x):
for i, selected_expert in enumerate(token_selected_experts[idx]):
if e_idx == selected_expert:
permuted_data_tensor[t_idx, :] = token
unpermute_map.append(idx)
scales.append(token_final_scales[idx, i])
t_idx += 1
expert_first_token_offset_tensor[e_idx + 1] = t_idx
act_input_fp8, act_input_sf = per_token_cast_to_fp8_e8m0(
permuted_data_tensor)
h1 = grouped_gemm(
a=act_input_fp8,
b=fused_moe.w3_w1_weight,
a_sf=act_input_sf,
b_sf=w3_w1_weight_scales,
offset_array=expert_first_token_offset_tensor,
)
h2 = swiglu_fused_moe(h1)
act_input_fp8, act_input_sf = per_token_cast_to_fp8_e8m0(h2)
h3 = grouped_gemm(
a=act_input_fp8,
b=fused_moe.w2_weight,
a_sf=act_input_sf,
b_sf=w2_weight_scales,
offset_array=expert_first_token_offset_tensor,
)
ref_output = torch.zeros_like(x)
for token_idx, h3_token in enumerate(h3):
original_idx = unpermute_map[token_idx]
ref_output[original_idx, :] += h3_token * scales[token_idx]
with torch.inference_mode():
output = fused_moe.forward(x, router_logits)
# compare
torch.cuda.synchronize()
torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)
@skip_non_hopper_unittest
@pytest.mark.parametrize(
"dtype, num_experts, seq_len, hidden_size, RoutingMethodCls",

View File

@ -18,10 +18,48 @@ import sys
import pytest
import torch
from _torch.helpers import calc_diff, per_block_cast_to_fp8
from _torch.helpers import (calc_diff, per_block_cast_to_fp8,
per_block_cast_to_fp8_e8m0,
per_token_cast_to_fp8_e8m0)
from utils.util import getSMVersion
@pytest.mark.skipif(
getSMVersion() != 100,
reason="The test is for Blackwell only. Current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize(
"k, n",
[(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096),
(2048, 7168), (1024, 1024)],
)
@pytest.mark.parametrize(
"m",
[7, 64, 128, 4096],
)
@pytest.mark.parametrize(
"dtype",
[torch.bfloat16],
)
def test_fp8_block_scale_deep_gemm(dtype, m, k, n):
torch.random.manual_seed(0)
a = torch.randn((m, k), device='cuda', dtype=dtype)
b = torch.randn((n, k), device='cuda', dtype=dtype)
act_a_fp8, act_a_sf = per_token_cast_to_fp8_e8m0(a)
act_b_fp8, act_b_sf = per_block_cast_to_fp8_e8m0(b)
output_expected = a @ b.t()
import deep_gemm
output = torch.empty((act_a_fp8.shape[0], act_b_fp8.shape[0]),
device=act_a_fp8.device,
dtype=torch.bfloat16)
deep_gemm.fp8_gemm_nt((act_a_fp8, act_a_sf), (act_b_fp8, act_b_sf), output)
diff = calc_diff(output, output_expected)
assert diff < 1e-2
@pytest.mark.skipif(
getSMVersion() != 100 and getSMVersion() != 89,
reason="The test is for Blackwell and Ada only. Current SM is %d." %

View File

@ -51,6 +51,9 @@ def test_pip_install():
help="The wheel path")
args = parser.parse_args()
if not os.environ.get("CUDA_HOME"):
os.environ["CUDA_HOME"] = "/usr/local/cuda"
print("########## Install required system libs ##########")
if not os.path.exists("/usr/local/mpi/bin/mpicc"):
subprocess.check_call("apt-get -y install libopenmpi-dev", shell=True)