Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Deepseek R1 FP8 Support on Blackwell (#6486)

Signed-off-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Co-authored-by: Barry Kang <43644113+Barry-Delaney@users.noreply.github.com>
Co-authored-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com>
Co-authored-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com>

Parent commit: 8c165fd27a
This commit:  7bb0a78631
@ -50,7 +50,10 @@ def add_llm_args(parser):
parser.add_argument('--moe_backend',
type=str,
default='CUTLASS',
choices=['CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP'])
choices=[
'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
'DEEPGEMM', 'CUTEDSL'
])
parser.add_argument('--enable_attention_dp',
default=False,
action='store_true')

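For reference, a minimal sketch of how the extended choice list behaves (plain argparse, no TensorRT-LLM needed; the flag spelling and choices come from the hunk above):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--moe_backend',
                    type=str,
                    default='CUTLASS',
                    choices=[
                        'CUTLASS', 'TRTLLM', 'VANILLA', 'WIDEEP',
                        'DEEPGEMM', 'CUTEDSL'
                    ])
args = parser.parse_args(['--moe_backend', 'DEEPGEMM'])
assert args.moe_backend == 'DEEPGEMM'  # the new Blackwell-oriented backend is now accepted
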
@ -61,4 +61,5 @@ etcd3
blake3
llguidance==0.7.29
soundfile
deep_gemm @ git+https://github.com/zongfeijing/DeepGEMM.git@a9d538ef4dff0326fe521c6ca0bfde115703b56a
triton==3.3.1

@ -12,7 +12,8 @@ from tensorrt_llm._torch.models.checkpoints.base_weight_loader import \
BaseWeightLoader
from tensorrt_llm._torch.models.modeling_utils import (
register_checkpoint_weight_loader, run_concurrently)
from tensorrt_llm._utils import local_mpi_rank, local_mpi_size
from tensorrt_llm._utils import (local_mpi_barrier, local_mpi_rank,
local_mpi_size)
from tensorrt_llm.logger import logger


@ -38,6 +39,8 @@ class HfWeightLoader(BaseWeightLoader):
f"Prefetching {prefetch_size / (1024**3):.2f}GB checkpoint files."
)
self.prefetch_files(weight_files)
# Ensure that all local ranks have finished prefetching before loading weights
local_mpi_barrier()

return self._load_weights_in_parallel(
weight_files, self._load_safetensors_file,

@ -38,12 +38,15 @@ from torch import nn
|
||||
from tqdm import tqdm
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from tensorrt_llm import logger
|
||||
from tensorrt_llm._ipc_utils import can_access_peer
|
||||
from tensorrt_llm._utils import get_sm_version
|
||||
from tensorrt_llm.functional import PositionEmbeddingType
|
||||
from tensorrt_llm.llmapi.utils import enable_llm_debug
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
from tensorrt_llm.quantization.utils.fp8_utils import (
|
||||
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
|
||||
|
||||
from ..attention_backend import AttentionMetadata
|
||||
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
|
||||
@ -1244,7 +1247,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
|
||||
weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size
|
||||
local_num_heads = num_heads // weight_divisor
|
||||
|
||||
k_nope_weight_trans = k_nope_weight.transpose(2, 1)
|
||||
k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous()
|
||||
|
||||
kv_b_proj = torch.concat([
|
||||
k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim,
|
||||
@ -1256,11 +1259,6 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
|
||||
|
||||
return kv_b_proj, k_nope_weight_trans
|
||||
|
||||
def check_weight_dtype(module_name: str, dtype):
|
||||
weight_name = "weight"
|
||||
w_dtype = weights[f"{module_name}.{weight_name}"].dtype
|
||||
return w_dtype == dtype
|
||||
|
||||
def load_kv_b_proj_and_k_b_proj_trans_dequant(
|
||||
module_name: str) -> torch.Tensor:
|
||||
weight_name = "weight"
|
||||
@ -1290,7 +1288,7 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
|
||||
weight_divisor = 1 if self.model_config.mapping.enable_attention_dp else tp_size
|
||||
local_num_heads = num_heads // weight_divisor
|
||||
|
||||
k_nope_weight_trans = k_nope_weight.transpose(2, 1)
|
||||
k_nope_weight_trans = k_nope_weight.transpose(2, 1).contiguous()
|
||||
|
||||
kv_b_proj = torch.concat([
|
||||
k_nope_weight.reshape(local_num_heads * local_qk_nope_head_dim,
|
||||
@ -1333,6 +1331,21 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
|
||||
params_map = {'gate_up_proj': ['gate_proj', 'up_proj']}
|
||||
all_named_modules = dict(self.named_modules())
|
||||
|
||||
if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
|
||||
) and get_sm_version() == 100:
|
||||
for name in list(weights.keys()):
|
||||
# Use ".experts." to exclude shared_experts.
|
||||
if name.endswith(
|
||||
"weight_scale_inv") and ".experts." not in name:
|
||||
weight_name = name.replace("weight_scale_inv", "weight")
|
||||
logger.debug(f"Resmoothing {weight_name}")
|
||||
weight = weights[weight_name][:]
|
||||
scale = weights[name][:]
|
||||
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
|
||||
weight, scale)
|
||||
weights[weight_name] = weights[weight_name].cpu()
|
||||
weights[name] = weights[name].cpu()
|
||||
|
||||
for name, module in tqdm(all_named_modules.items(),
|
||||
desc="Loading weights"):
|
||||
if len(module._parameters) > 0:
|
||||
@ -1384,6 +1397,26 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
|
||||
attn_module.v_b_proj_scale = nn.Parameter(
|
||||
v_b_proj_scale, requires_grad=False)
|
||||
|
||||
if attn_module.k_b_proj_trans_dequant is not None:
|
||||
attn_module.k_b_proj_trans_dequant.data.copy_(
|
||||
weight_dequant(
|
||||
k_b_proj_trans.view(
|
||||
-1, k_b_proj_trans.shape[-1]).cuda(),
|
||||
k_b_proj_trans_scale.view(
|
||||
-1,
|
||||
k_b_proj_trans_scale.shape[-1]).cuda(),
|
||||
).view(
|
||||
*attn_module.k_b_proj_trans_dequant.shape).
|
||||
to(attn_module.k_b_proj_trans_dequant.dtype))
|
||||
if attn_module.v_b_proj_dequant is not None:
|
||||
attn_module.v_b_proj_dequant.data.copy_(
|
||||
weight_dequant(
|
||||
v_b_proj.view(-1,
|
||||
v_b_proj.shape[-1]).cuda(),
|
||||
v_b_proj_scale.view(
|
||||
-1, v_b_proj_scale.shape[-1]).cuda(),
|
||||
).view(*attn_module.v_b_proj_dequant.shape).to(
|
||||
attn_module.v_b_proj_dequant.dtype))
|
||||
elif names[-1] == "kv_a_proj_with_mqa":
|
||||
fused_a = weights[
|
||||
f"{'.'.join(names[:-1])}.kv_a_proj_with_mqa.weight"][:]
|
||||
@ -1431,6 +1464,18 @@ class DeepseekV3ForCausalLM(DecoderModelForCausalLM[DeepseekV3Model,
|
||||
for n, p in module.named_parameters():
|
||||
p.data.copy_(module_weights[n][:])
|
||||
|
||||
if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
|
||||
) and get_sm_version() == 100 and hasattr(
|
||||
module, "weight_scale"):
|
||||
transfromed_scale = transform_sf_into_required_layout(
|
||||
module.weight_scale,
|
||||
mn=module.weight.shape[0],
|
||||
k=module.weight.shape[1],
|
||||
recipe=(1, 128, 128),
|
||||
is_sfa=False)
|
||||
module.weight_scale = nn.Parameter(transfromed_scale,
|
||||
requires_grad=False)
|
||||
|
||||
for idx, layer in enumerate(
|
||||
self.model.layers[:self.config.num_hidden_layers]):
|
||||
if idx == self.config.num_hidden_layers - 1:
|
||||
|
||||
@ -357,6 +357,7 @@ def fp8_block_scaling_bmm_out(
|
||||
mat2_fp8: torch.Tensor,
|
||||
mat2_scale: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
mat2_dequant: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
sm_version = get_sm_version()
|
||||
if sm_version == 90 or sm_version == 89:
|
||||
@ -365,30 +366,33 @@ def fp8_block_scaling_bmm_out(
|
||||
torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,
|
||||
mat1_scale, mat2_scale, out)
|
||||
elif sm_version == 100:
|
||||
low_latency = True
|
||||
use_deep_seek_fp8 = True
|
||||
tile_size = 8
|
||||
epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
|
||||
m_size = mat1.shape[0]
|
||||
if m_size % tile_size != 0:
|
||||
tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size
|
||||
mat1 = torch.nn.functional.pad(
|
||||
mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0)
|
||||
output = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2))
|
||||
out.copy_(output)
|
||||
|
||||
mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
|
||||
mat1)
|
||||
output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
|
||||
mat1_fp8,
|
||||
mat2_fp8,
|
||||
tile_size=tile_size,
|
||||
epilogue_tile_m=epilogue_tile_m,
|
||||
use_deep_seek_fp8=use_deep_seek_fp8,
|
||||
low_latency=low_latency,
|
||||
dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1),
|
||||
dq_sfs_b=mat2_scale,
|
||||
out_dtype=out.dtype,
|
||||
)
|
||||
out.copy_(output[:, :m_size])
|
||||
# low_latency = True
|
||||
# use_deep_seek_fp8 = True
|
||||
# tile_size = 8
|
||||
# epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
|
||||
# m_size = mat1.shape[0]
|
||||
# if m_size % tile_size != 0:
|
||||
# tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size
|
||||
# mat1 = torch.nn.functional.pad(
|
||||
# mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0)
|
||||
|
||||
# mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
|
||||
# mat1)
|
||||
# output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
|
||||
# mat1_fp8,
|
||||
# mat2_fp8,
|
||||
# tile_size=tile_size,
|
||||
# epilogue_tile_m=epilogue_tile_m,
|
||||
# use_deep_seek_fp8=use_deep_seek_fp8,
|
||||
# low_latency=low_latency,
|
||||
# dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1),
|
||||
# dq_sfs_b=mat2_scale,
|
||||
# out_dtype=out.dtype,
|
||||
# )
|
||||
# out.copy_(output[:, :m_size])
|
||||
else:
|
||||
raise NotImplementedError(f"SM{sm_version} is not supported")
|
||||
|
||||
@ -676,6 +680,8 @@ class MLA(nn.Module):
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
self.k_b_proj_trans_dequant = None
|
||||
self.v_b_proj_dequant = None
|
||||
if has_fp8_block_scales:
|
||||
self.k_b_proj_trans_scale = nn.Parameter(
|
||||
torch.empty(
|
||||
@ -701,6 +707,23 @@ class MLA(nn.Module):
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
if get_sm_version() == 100:
|
||||
assert self.dtype == torch.bfloat16
|
||||
self.k_b_proj_trans_dequant = nn.Parameter(
|
||||
torch.empty(
|
||||
(self.num_heads, self.kv_lora_rank,
|
||||
self.qk_nope_head_dim),
|
||||
dtype=self.dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.v_b_proj_dequant = nn.Parameter(
|
||||
torch.empty(
|
||||
(self.num_heads, self.v_head_dim, self.kv_lora_rank),
|
||||
dtype=self.dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
else:
|
||||
self.k_b_proj_trans_scale = None
|
||||
self.v_b_proj_scale = None
|
||||
@ -1197,8 +1220,13 @@ class MLA(nn.Module):
|
||||
# [num_heads, num_tokens, self.kv_lora_rank]
|
||||
q_nope_out = fused_q[..., :self.kv_lora_rank].transpose(0, 1)
|
||||
|
||||
fp8_block_scaling_bmm_out(q_nope, self.k_b_proj_trans,
|
||||
self.k_b_proj_trans_scale, q_nope_out)
|
||||
fp8_block_scaling_bmm_out(
|
||||
q_nope,
|
||||
self.k_b_proj_trans,
|
||||
self.k_b_proj_trans_scale,
|
||||
q_nope_out,
|
||||
self.k_b_proj_trans_dequant,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Missing bmm impl for dtype: {self.k_b_proj_trans.dtype}.")
|
||||
@ -1247,9 +1275,13 @@ class MLA(nn.Module):
|
||||
self.v_b_proj.transpose(1, 2),
|
||||
attn_output.transpose(0, 1))
|
||||
elif self.v_b_proj.dtype == torch.float8_e4m3fn:
|
||||
fp8_block_scaling_bmm_out(attn_out_latent, self.v_b_proj,
|
||||
self.v_b_proj_scale,
|
||||
attn_output.transpose(0, 1))
|
||||
fp8_block_scaling_bmm_out(
|
||||
attn_out_latent,
|
||||
self.v_b_proj,
|
||||
self.v_b_proj_scale,
|
||||
attn_output.transpose(0, 1),
|
||||
self.v_b_proj_dequant,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Missing bmm impl for dtype: {self.v_b_proj.dtype}.")
|
||||
|
||||
@ -8,6 +8,7 @@ from tensorrt_llm.models.modeling_utils import QuantConfig
|
||||
from ...model_config import ModelConfig
|
||||
from .fused_moe_cute_dsl import CuteDslFusedMoE
|
||||
from .fused_moe_cutlass import CutlassFusedMoE
|
||||
from .fused_moe_deepgemm import DeepGemmFusedMoE
|
||||
from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE
|
||||
from .fused_moe_vanilla import VanillaMoE
|
||||
from .fused_moe_wide_ep import WideEPMoE
|
||||
@ -31,6 +32,8 @@ def get_moe_cls(
|
||||
return VanillaMoE
|
||||
elif moe_backend.upper() == "CUTEDSL":
|
||||
return CuteDslFusedMoE
|
||||
elif moe_backend.upper() == "DEEPGEMM":
|
||||
return DeepGemmFusedMoE
|
||||
elif moe_backend.upper() == "TRTLLM":
|
||||
if quant_config is not None and (
|
||||
quant_config.quant_mode.has_fp8_block_scales()
|
||||
@ -139,5 +142,19 @@ def create_moe(
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
elif moe_cls == DeepGemmFusedMoE:
|
||||
return moe_cls(
|
||||
routing_method=routing_method,
|
||||
num_experts=num_experts,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
dtype=dtype,
|
||||
reduce_results=reduce_results,
|
||||
model_config=model_config,
|
||||
aux_stream=aux_stream,
|
||||
weight_loading_mode=weight_loading_mode,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported moe backend: {moe_cls}")
|
||||
|
||||
483 lines: tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py (new file)
@ -0,0 +1,483 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import deep_gemm
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
|
||||
from tensorrt_llm._utils import nvtx_range
|
||||
|
||||
from ...distributed import allgather
|
||||
from ...model_config import ModelConfig
|
||||
from ...utils import Fp4QuantizedTensor
|
||||
from .fused_moe_cutlass import CutlassFusedMoE
|
||||
from .quantization import MoEWeightLoadingMode
|
||||
from .routing import BaseMoeRoutingMethod
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _masked_index_copy_group_quant_fp8(
|
||||
input_ptr,
|
||||
out_q_ptr,
|
||||
out_s_ptr,
|
||||
# mask indices
|
||||
start_offsets_ptr,
|
||||
row_indices_ptr,
|
||||
# dimensions
|
||||
row_size,
|
||||
col_size,
|
||||
dim_size,
|
||||
group_size,
|
||||
# output scale factor size
|
||||
aligned_col,
|
||||
aligned_dim,
|
||||
# quantization parameters
|
||||
eps,
|
||||
fp8_max,
|
||||
# block size
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_STAGE: tl.constexpr,
|
||||
):
|
||||
group_block = tl.program_id(0)
|
||||
token_block = tl.program_id(1)
|
||||
token_block_num = tl.num_programs(1)
|
||||
|
||||
# calculate group and element offsets
|
||||
num_tokens = tl.load(start_offsets_ptr + row_size)
|
||||
elem_offsets = group_block * group_size * 4 + tl.arange(0, BLOCK)
|
||||
output_s_offs = out_s_ptr + group_block * aligned_col
|
||||
|
||||
# process tokens
|
||||
for token_index in tl.range(token_block,
|
||||
num_tokens,
|
||||
token_block_num,
|
||||
num_stages=NUM_STAGE):
|
||||
# load indices
|
||||
row_idx = tl.load(row_indices_ptr + token_index)
|
||||
start_offset = tl.load(start_offsets_ptr + row_idx)
|
||||
idx = row_idx * col_size + token_index - start_offset
|
||||
idx_s = row_idx * aligned_dim * aligned_col + token_index - start_offset
|
||||
|
||||
output_s_int32 = 0
|
||||
for group_index in tl.range(4):
|
||||
# load input data
|
||||
dim_offset = elem_offsets + group_index * group_size
|
||||
valid = dim_offset < dim_size
|
||||
input_data = tl.load(input_ptr + token_index * dim_size +
|
||||
dim_offset,
|
||||
mask=valid,
|
||||
other=0.0)
|
||||
# quantization
|
||||
_absmax = tl.maximum(tl.max(tl.abs(input_data)), eps)
|
||||
output_s = _absmax / fp8_max
|
||||
output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
|
||||
output_q = tl.clamp(input_data / output_s, -fp8_max,
|
||||
fp8_max).to(out_q_ptr.dtype.element_ty)
|
||||
output_s = output_s.to(tl.int32, bitcast=True) >> 23
|
||||
output_s_int32 += output_s << (group_index * 8)
|
||||
|
||||
# store quantized values and scaling factor
|
||||
tl.store(out_q_ptr + idx * dim_size + dim_offset,
|
||||
output_q,
|
||||
mask=valid)
|
||||
tl.store(output_s_offs + idx_s, output_s_int32)
|
||||
|
||||
|
||||
def masked_index_copy_group_quant_fp8(
|
||||
output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
start_offsets: torch.Tensor,
|
||||
row_indices: torch.Tensor,
|
||||
group_size: int,
|
||||
eps: float = 1e-10,
|
||||
):
|
||||
assert (
input.shape[-1] % group_size == 0
), "the last dimension of `input` must be divisible by `group_size`"
|
||||
assert input.is_contiguous(), "`input` is not contiguous"
|
||||
assert input.ndim == 2, "Input must be a 2D tensor"
|
||||
assert output.ndim == 3, "Output must be a 3D tensor, [row, col, dim]"
|
||||
assert start_offsets.shape[
|
||||
0] == output.shape[0] + 1, "Start offsets must be (num_experts + 1)"
|
||||
|
||||
num_tokens = input.shape[0]
|
||||
row_size = output.shape[0]
|
||||
col_size = output.shape[1]
|
||||
dim_size = output.shape[2]
|
||||
|
||||
# create padded output_s
|
||||
alignment = 4
|
||||
scale_dim = (dim_size + group_size - 1) // group_size
|
||||
padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
|
||||
padded_col_size = (col_size + alignment - 1) // alignment * alignment
|
||||
output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size),
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
|
||||
# get block/grid/stage/warp
|
||||
num_groups = (dim_size + group_size - 1) // group_size
|
||||
BLOCK = group_size
|
||||
if num_tokens <= 1000 or col_size <= 256: # Small workload
|
||||
TOKEN_BLOCK_NUM = 256
|
||||
NUM_STAGES = 4
|
||||
num_warps = 2
|
||||
elif num_tokens <= 10000 or col_size <= 2048: # Medium workload
|
||||
TOKEN_BLOCK_NUM = 1024
|
||||
NUM_STAGES = 2
|
||||
num_warps = 1
|
||||
else: # Large workload
|
||||
TOKEN_BLOCK_NUM = 2048
|
||||
NUM_STAGES = 2
|
||||
num_warps = 1
|
||||
grid = (
|
||||
(num_groups + 3) // 4,
|
||||
TOKEN_BLOCK_NUM,
|
||||
)
|
||||
|
||||
# FP8 quantization parameters
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max = finfo.max
|
||||
|
||||
_masked_index_copy_group_quant_fp8[grid](
|
||||
input,
|
||||
output,
|
||||
output_s,
|
||||
start_offsets,
|
||||
row_indices,
|
||||
row_size,
|
||||
col_size,
|
||||
dim_size,
|
||||
group_size,
|
||||
padded_col_size,
|
||||
padded_dim_size // 4,
|
||||
eps,
|
||||
fp8_max,
|
||||
BLOCK=BLOCK,
|
||||
NUM_STAGE=NUM_STAGES,
|
||||
num_warps=num_warps,
|
||||
)
|
||||
output_s = output_s.transpose(1, 2)[:, :col_size, :]
|
||||
return output_s
|
||||
|
||||
|
||||
@triton.jit
|
||||
def masked_index_gather_kernel(output_ptr, input_ptr, start_offsets_ptr,
|
||||
row_indices_ptr, row_size, col_size, dim_size,
|
||||
BLOCK_SIZE: tl.constexpr):
|
||||
# get program id and block offset
|
||||
pid = tl.program_id(0)
|
||||
num_tokens = tl.load(start_offsets_ptr + row_size)
|
||||
|
||||
token_idx = pid
|
||||
valid_token = token_idx < num_tokens
|
||||
if not valid_token:
|
||||
return
|
||||
|
||||
row_idx = tl.load(row_indices_ptr + token_idx)
|
||||
start_offset = tl.load(start_offsets_ptr + row_idx)
|
||||
col_idx = token_idx - start_offset
|
||||
|
||||
# Process elements in blocks
|
||||
for hidden_start in tl.range(0, dim_size, BLOCK_SIZE):
|
||||
hidden_indices = hidden_start + tl.arange(0, BLOCK_SIZE)
|
||||
valid_hidden = hidden_indices < dim_size
|
||||
|
||||
input_offset = row_idx * col_size * dim_size + col_idx * dim_size + hidden_indices
|
||||
input_val = tl.load(input_ptr + input_offset,
|
||||
mask=valid_hidden,
|
||||
other=0.0)
|
||||
|
||||
output_offset = pid * dim_size + hidden_indices
|
||||
tl.store(output_ptr + output_offset, input_val, mask=valid_hidden)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def triton_masked_index_gather(output, input, start_offsets, row_indices):
|
||||
assert output.ndim == 2, "Output must be a 2D tensor"
|
||||
assert input.ndim == 3, "Input must be a 3D tensor, [row, col, dim]"
|
||||
assert start_offsets.shape[
|
||||
0] == input.shape[0] + 1, "Start offsets must be (num_experts + 1)"
|
||||
|
||||
row_size = input.shape[0]
|
||||
col_size = input.shape[1]
|
||||
dim_size = input.shape[2]
|
||||
num_tokens = output.shape[0]
|
||||
|
||||
grid = (num_tokens, )
|
||||
# launch kernel
|
||||
masked_index_gather_kernel[grid](output,
|
||||
input,
|
||||
start_offsets,
|
||||
row_indices,
|
||||
row_size,
|
||||
col_size,
|
||||
dim_size,
|
||||
BLOCK_SIZE=1024)
|
||||
return
|
||||
|
||||
|
||||
@nvtx_range("[DG] act")
|
||||
@torch.compile(dynamic=True)
|
||||
def swiglu_fused_moe(x):
|
||||
x, gate = x.chunk(2, dim=-1)
|
||||
return F.silu(gate) * x
|
||||
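A tiny sanity check of the chunk convention used above (the first half of the last dimension is the up projection, the second half is the gate); values are arbitrary and this is only an illustrative sketch:

import torch
import torch.nn.functional as F

t = torch.randn(4, 8)
up, gate = t.chunk(2, dim=-1)
assert torch.allclose(swiglu_fused_moe(t), F.silu(gate) * up)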
|
||||
|
||||
@nvtx_range("[DG] indexing")
|
||||
@torch.compile(dynamic=True)
|
||||
def indexing(x, mask):
|
||||
return x[mask > 0, :].contiguous()
|
||||
|
||||
|
||||
@nvtx_range("[DG] preprocess_after_permute")
|
||||
def preprocess_after_permute(expert_first_token_offset_tensor,
|
||||
permuted_data_tensor):
|
||||
# get tokens per expert
|
||||
masked_m = expert_first_token_offset_tensor[
|
||||
1:] - expert_first_token_offset_tensor[:-1]
|
||||
token_to_expert_map = torch.searchsorted(
|
||||
expert_first_token_offset_tensor[1:],
|
||||
torch.arange(permuted_data_tensor.shape[0], device='cuda'),
|
||||
right=True)
|
||||
return masked_m.to(torch.int32), token_to_expert_map
|
||||
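A small, hypothetical illustration of the mapping computed above (CUDA device assumed); note how `right=True` in `torch.searchsorted` makes tokens skip over experts that received no tokens (expert 1 below):

import torch

offsets = torch.tensor([0, 3, 3, 7], device='cuda')  # expert_first_token_offset for 3 experts
permuted = torch.empty(7, 16, device='cuda')          # 7 permuted tokens (contents irrelevant here)
masked_m, token_to_expert = preprocess_after_permute(offsets, permuted)
print(masked_m)          # [3, 0, 4] (int32) -> tokens per expert
print(token_to_expert)   # [0, 0, 0, 2, 2, 2, 2] -> owning expert per permuted token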
|
||||
|
||||
@nvtx_range("[DG]")
|
||||
def deepgemm_fp8_group_blockwise_gemm(
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
sfa: torch.Tensor,
|
||||
sfb: torch.Tensor,
|
||||
masked_m: torch.Tensor,
|
||||
expected_m: int,
|
||||
) -> torch.Tensor:
|
||||
d = torch.empty((a.shape[0], a.shape[1], b.shape[1]),
|
||||
device=b.device,
|
||||
dtype=torch.bfloat16)
|
||||
|
||||
# NOTES: shape must be `[G, M, K] @ [G, N, K].mT`
|
||||
assert a.stride(-1) == 1
|
||||
assert b.stride(-1) == 1
|
||||
assert masked_m.is_contiguous()
|
||||
|
||||
num_groups, m, k = a.shape
|
||||
num_groups_, n, k_ = b.shape
|
||||
num_groups__, m_, n_ = d.shape
|
||||
num_groups___ = masked_m.numel()
|
||||
|
||||
# Type and shape checks
|
||||
assert num_groups == num_groups_ == num_groups__ == num_groups___
|
||||
assert m == m_ and n == n_ and k == k_
|
||||
assert expected_m > 0 and m > 0 and n > 0 and k > 0 and num_groups > 0
|
||||
assert a.dtype == torch.float8_e4m3fn
|
||||
assert b.dtype == torch.float8_e4m3fn
|
||||
assert d.dtype == torch.bfloat16
|
||||
assert masked_m.dtype == torch.int32
|
||||
|
||||
# D must be N-major
|
||||
assert d.stride(-1) == 1
|
||||
|
||||
# Transform SFA and SFB into compute-required layout
|
||||
|
||||
deep_gemm.fp8_m_grouped_gemm_nt_masked((a, sfa), (b, sfb),
|
||||
d,
|
||||
masked_m,
|
||||
expected_m,
|
||||
disable_ue8m0_cast=True)
|
||||
return d
|
||||
|
||||
|
||||
class DeepGemmFusedMoE(CutlassFusedMoE):
|
||||
"""
|
||||
Python Flow of Fused Mixture of Experts (MoE) Layer.
|
||||
|
||||
Args:
|
||||
num_experts (int): Number of experts in the MoE layer.
|
||||
top_k (int): Number of top experts to select for each input token.
|
||||
hidden_size (int): Size of the hidden state.
|
||||
intermediate_size (int): Size of the intermediate state.
|
||||
aux_stream (Optional[torch.cuda.Stream]): Auxiliary CUDA stream to overlap chunks.
|
||||
dtype (Optional[torch.dtype]): Data type for the weights.
|
||||
reduce_results (bool): Whether to reduce the results across devices.
|
||||
model_config (ModelConfig): Configuration object for the model.
|
||||
|
||||
This backend is composed of multiple custom ops:
|
||||
1. moe_permute_op: permute the input tensor and the expert selected tensor.
|
||||
2. deepgemm_fp8_group_blockwise_gemm: the DeepGEMM masked grouped block-wise FP8 GEMM used for both expert GEMMs.
|
||||
3. moe_finalize_scale_op: finalize the scale of the output tensor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
routing_method: BaseMoeRoutingMethod,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = False,
|
||||
model_config: ModelConfig = ModelConfig(),
|
||||
aux_stream: Optional[torch.cuda.Stream] = None,
|
||||
weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.
|
||||
VANILLA,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
layer_idx: Optional[int] = None,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
routing_method=routing_method,
|
||||
num_experts=num_experts,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
dtype=dtype,
|
||||
reduce_results=reduce_results,
|
||||
model_config=model_config,
|
||||
aux_stream=aux_stream,
|
||||
weight_loading_mode=weight_loading_mode,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
@nvtx_range("[DG] forward")
|
||||
def forward_chunk(
|
||||
self,
|
||||
x: Union[torch.Tensor, Fp4QuantizedTensor],
|
||||
router_logits: torch.Tensor,
|
||||
output_dtype: Optional[torch.dtype] = None,
|
||||
all_rank_num_tokens: Optional[List[int]] = None,
|
||||
use_dp_padding: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(x, Fp4QuantizedTensor):
|
||||
assert output_dtype is not None
|
||||
output_dtype = output_dtype
|
||||
else:
|
||||
output_dtype = x.dtype
|
||||
|
||||
# apply routing
|
||||
token_selected_experts, token_final_scales = self.routing_method.apply(
|
||||
router_logits)
|
||||
assert token_selected_experts.shape[
|
||||
1] == self.routing_method.experts_per_token
|
||||
assert token_selected_experts.shape == token_final_scales.shape
|
||||
assert token_selected_experts.shape[0] == router_logits.shape[0]
|
||||
assert token_final_scales.dtype == torch.float32
|
||||
assert token_selected_experts.dtype == torch.int32
|
||||
|
||||
if self.apply_router_weight_on_input:
|
||||
assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing"
|
||||
assert x.dtype != torch.float8_e4m3fn, "Current workaround for apply_router_weight_on_input does not support fp8 input"
|
||||
x = x * token_final_scales.to(x.dtype)
|
||||
# TODO: remove this once we have correct fusedmoe kernel ready
|
||||
token_final_scales = None
|
||||
|
||||
# quantize inputs
|
||||
use_deepseek_fp8_block_scale = False
|
||||
x_sf = None
|
||||
if self.has_any_quant:
|
||||
if self.has_deepseek_fp8_block_scales:
|
||||
use_deepseek_fp8_block_scale = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"unsupported quantization mode for CUTEDSL backend: {self.quant_config.quant_mode}"
|
||||
)
|
||||
|
||||
use_allgather = self.use_dp and self.parallel_size > 1
|
||||
if use_allgather:
|
||||
x, x_sf, token_selected_experts, token_final_scales = allgather(
|
||||
[x, x_sf, token_selected_experts, token_final_scales],
|
||||
self.mapping,
|
||||
dim=0,
|
||||
sizes=None if use_dp_padding else all_rank_num_tokens)
|
||||
|
||||
(
|
||||
permuted_row_to_unpermuted_row_tensor,
|
||||
permuted_token_selected_experts_tensor,
|
||||
permuted_data_tensor,
|
||||
expert_first_token_offset_tensor,
|
||||
permuted_token_final_scales_tensor,
|
||||
unpermuted_row_to_permuted_row_tensor,
|
||||
) = torch.ops.trtllm.moe_permute_op(
|
||||
x,
|
||||
token_selected_experts,
|
||||
token_final_scales,
|
||||
None, # w3_w1_weight.view(weight_dtype),
|
||||
None, # w2_weight.view(weight_dtype),
|
||||
None, # quant_scales,
|
||||
input_sf=x_sf,
|
||||
num_experts_on_rank=self.expert_size_per_partition,
|
||||
tp_size=self.tp_size,
|
||||
tp_rank=self.tp_rank,
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
cluster_size=self.cluster_size,
|
||||
cluster_rank=self.cluster_rank,
|
||||
min_latency_mode=False,
|
||||
use_fp8_block_scaling=use_deepseek_fp8_block_scale,
|
||||
)
|
||||
|
||||
if permuted_data_tensor.numel() == 0:
|
||||
return torch.zeros_like(x)
|
||||
|
||||
masked_m, token_to_expert_map = preprocess_after_permute(
|
||||
expert_first_token_offset_tensor, permuted_data_tensor)
|
||||
|
||||
m_max = (x.shape[0] + 127) // 128 * 128
|
||||
expected_m = (token_selected_experts.numel() +
|
||||
self.expert_size_per_partition -
|
||||
1) // self.expert_size_per_partition
|
||||
act_input_fp8 = torch.empty(
|
||||
(self.expert_size_per_partition, m_max, self.hidden_size),
|
||||
dtype=torch.float8_e4m3fn,
|
||||
device='cuda')
|
||||
act_input_sf = masked_index_copy_group_quant_fp8(
|
||||
act_input_fp8,
|
||||
permuted_data_tensor,
|
||||
expert_first_token_offset_tensor,
|
||||
token_to_expert_map,
|
||||
group_size=128)
|
||||
|
||||
h1 = deepgemm_fp8_group_blockwise_gemm(
|
||||
a=act_input_fp8,
|
||||
b=self.w3_w1_weight,
|
||||
sfa=act_input_sf,
|
||||
sfb=self.quant_scales[0],
|
||||
masked_m=masked_m,
|
||||
expected_m=expected_m,
|
||||
)
|
||||
act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
|
||||
input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True)
|
||||
h3 = deepgemm_fp8_group_blockwise_gemm(
|
||||
a=act_input_fp8,
|
||||
b=self.w2_weight,
|
||||
sfa=act_input_sf,
|
||||
sfb=self.quant_scales[1],
|
||||
masked_m=masked_m,
|
||||
expected_m=expected_m,
|
||||
)
|
||||
|
||||
triton_masked_index_gather(permuted_data_tensor, h3,
|
||||
expert_first_token_offset_tensor,
|
||||
token_to_expert_map)
|
||||
|
||||
final_hidden_states = torch.ops.trtllm.moe_finalize_scale_op(
|
||||
permuted_data_tensor,
|
||||
None, # biases
|
||||
token_final_scales,
|
||||
unpermuted_row_to_permuted_row_tensor,
|
||||
permuted_row_to_unpermuted_row_tensor,
|
||||
token_selected_experts,
|
||||
expert_first_token_offset_tensor,
|
||||
False, # enable_alltoall
|
||||
x.shape[0], # num_rows
|
||||
x.shape[1], # hidden_size
|
||||
self.routing_method.top_k,
|
||||
self.expert_size_per_partition, # num_experts_per_node
|
||||
self.tp_size,
|
||||
self.tp_rank,
|
||||
self.ep_size,
|
||||
self.ep_rank,
|
||||
)
|
||||
|
||||
return final_hidden_states
|
||||
@ -4,10 +4,13 @@ from typing import Dict, List, NamedTuple, Union
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from tensorrt_llm import logger
|
||||
from tensorrt_llm._utils import get_sm_version
|
||||
from tensorrt_llm.quantization.utils.fp4_utils import (
|
||||
float4_sf_dtype, get_reorder_rows_for_gated_act_gemm_row_indices,
|
||||
get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices)
|
||||
from tensorrt_llm.quantization.utils.fp8_utils import (
|
||||
resmooth_to_fp8_e8m0, transform_sf_into_required_layout)
|
||||
|
||||
from ..linear import TensorParallelMode, load_weight_shard
|
||||
from .interface import MoEWeightLoadingMode
|
||||
@ -463,6 +466,47 @@ class DeepSeekFP8BlockScalesFusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
self.setup_quant_scales(module)
|
||||
|
||||
def load_weights(self, module: torch.nn.Module, weights: List[Dict],
|
||||
weight_loading_mode: MoEWeightLoadingMode):
|
||||
|
||||
if get_sm_version() == 100:
|
||||
expert_ids = set(module.initial_local_expert_ids)
|
||||
if self.need_load_shared_weights(module):
|
||||
expert_ids.update(
|
||||
module.layer_load_balancer.get_load_expert_ids())
|
||||
for name in list(weights.keys()):
|
||||
if name.endswith("weight_scale_inv"):
|
||||
if int(name.split(".")[0]) not in expert_ids:
|
||||
continue
|
||||
weight_name = name.replace("weight_scale_inv", "weight")
|
||||
logger.debug(f"Resmoothing {weight_name}")
|
||||
weight = weights[weight_name][:]
|
||||
scale = weights[name][:]
|
||||
weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
|
||||
weight, scale)
|
||||
super().load_weights(module, weights, weight_loading_mode)
|
||||
|
||||
if get_sm_version() == 100:
|
||||
transfromed_w3_w1_scale = transform_sf_into_required_layout(
|
||||
module.quant_scales[0],
|
||||
mn=module.w3_w1_weight.shape[1],
|
||||
k=module.w3_w1_weight.shape[2],
|
||||
recipe=(1, 128, 128),
|
||||
num_groups=module.w3_w1_weight.shape[0],
|
||||
is_sfa=False)
|
||||
module.w3_w1_weight_scaling_factor = nn.Parameter(
|
||||
transfromed_w3_w1_scale, requires_grad=False)
|
||||
transfromed_w2_scale = transform_sf_into_required_layout(
|
||||
module.quant_scales[1],
|
||||
mn=module.w2_weight.shape[1],
|
||||
k=module.w2_weight.shape[2],
|
||||
recipe=(1, 128, 128),
|
||||
num_groups=module.w3_w1_weight.shape[0],
|
||||
is_sfa=False)
|
||||
module.w2_weight_scaling_factor = nn.Parameter(transfromed_w2_scale,
|
||||
requires_grad=False)
|
||||
self.setup_quant_scales(module)
|
||||
|
||||
def setup_quant_scales(self, module: torch.nn.Module):
|
||||
module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales(
|
||||
fc_weight_scales=module.w3_w1_weight_scaling_factor,
|
||||
|
||||
@ -12,6 +12,7 @@ from torch import nn
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
|
||||
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
|
||||
from tensorrt_llm._torch.peft.lora.layer import LoraLayer
|
||||
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
|
||||
AllReduceStrategy)
|
||||
@ -20,6 +21,7 @@ from tensorrt_llm.quantization.functional import \
|
||||
preprocess_weights_for_mixed_gemm
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
|
||||
from ..._utils import get_sm_version
|
||||
from ...models.modeling_utils import QuantConfig
|
||||
from ..utils import Fp4QuantizedTensor
|
||||
|
||||
@ -570,10 +572,22 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
|
||||
input = input.to(torch.bfloat16) * module.input_scale
|
||||
assert input.dtype == torch.bfloat16
|
||||
|
||||
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(input)
|
||||
if get_sm_version() == 100:
|
||||
import deep_gemm
|
||||
a, a_sf = fp8_utils.per_token_quant_and_transform(input)
|
||||
output = torch.empty((input.shape[0], module.weight.shape[0]),
|
||||
device=input.device,
|
||||
dtype=torch.bfloat16)
|
||||
deep_gemm.fp8_gemm_nt((a, a_sf),
|
||||
(module.weight, module.weight_scale),
|
||||
output,
|
||||
disable_ue8m0_cast=True)
|
||||
else:
|
||||
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
|
||||
input)
|
||||
|
||||
output = torch.ops.trtllm.fp8_block_scaling_gemm(
|
||||
act_input_fp8, module.weight, act_input_sf, module.weight_scale)
|
||||
output = torch.ops.trtllm.fp8_block_scaling_gemm(
|
||||
act_input_fp8, module.weight, act_input_sf, module.weight_scale)
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output
|
||||
@ -603,7 +617,6 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
|
||||
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
|
||||
module, weights)
|
||||
fused_weight = torch.cat((q_weight, k_weight, v_weight))
|
||||
copy_weight(module.weight, fused_weight)
|
||||
|
||||
scale_name = self._get_scale_name(weights)
|
||||
q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
|
||||
@ -614,6 +627,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
|
||||
module.tp_rank, module.tp_mode)
|
||||
fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
|
||||
|
||||
copy_weight(module.weight, fused_weight)
|
||||
copy_weight(module.weight_scale, fused_fp8_block_scale)
|
||||
|
||||
def load_weights_fused_gate_up_linear(self, module: Linear,
|
||||
@ -621,7 +635,6 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
|
||||
gate_weight, up_weight = load_weights_fused_gate_up_helper(
|
||||
module, weights)
|
||||
fused_weight = torch.cat((gate_weight, up_weight))
|
||||
copy_weight(module.weight, fused_weight)
|
||||
|
||||
scale_name = self._get_scale_name(weights)
|
||||
left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
|
||||
@ -629,6 +642,7 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
|
||||
right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
|
||||
module.tp_rank, module.tp_mode)
|
||||
fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
|
||||
copy_weight(module.weight, fused_weight)
|
||||
copy_weight(module.weight_scale, fused_scale)
|
||||
|
||||
|
||||
|
||||
@ -144,6 +144,8 @@ class KvCacheCreator:
|
||||
end_id=-1)
|
||||
requests.append(request)
|
||||
remaining_tokens -= input_seq_len
|
||||
if self._mapping.enable_attention_dp:
|
||||
requests = requests * self._mapping.tp_size
|
||||
return requests
|
||||
|
||||
def _get_token_num_for_estimation(self) -> int:
|
||||
|
||||
@ -470,6 +470,10 @@ def mpi_comm():
local_comm = mpi_comm().Split_type(split_type=OMPI_COMM_TYPE_HOST)


def local_mpi_comm():
return local_comm


def mpi_rank():
return mpi_comm().Get_rank() if ENABLE_MULTI_DEVICE else 0


@ -508,6 +512,11 @@ def mpi_barrier():
mpi_comm().Barrier()


def local_mpi_barrier():
if ENABLE_MULTI_DEVICE:
local_comm.Barrier()


def mpi_broadcast(obj, root=0):
return mpi_comm().bcast(obj, root) if is_multi_device_enable() else obj

@ -167,7 +167,7 @@ class MoeConfig(StrictBaseModel):
"""
Configuration for MoE.
"""
backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM",
backend: Literal["CUTLASS", "CUTEDSL", "WIDEEP", "TRTLLM", "DEEPGEMM",
"VANILLA"] = Field(default='CUTLASS',
description="MoE backend to use.")

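As a usage sketch, the new literal can be selected through the LLM API configuration; the import path below is an assumption based on where `MoeConfig` is defined, not something stated in this diff:

from tensorrt_llm.llmapi import MoeConfig  # assumed import location

moe_config = MoeConfig(backend="DEEPGEMM")  # "DEEPGEMM" is newly accepted by the Literal above
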
@ -1,3 +1,3 @@
from . import fp4_utils
from . import fp4_utils, fp8_utils

__all__ = ['fp4_utils']
__all__ = ['fp4_utils', 'fp8_utils']

530 lines: tensorrt_llm/quantization/utils/fp8_utils.py (new file)
@ -0,0 +1,530 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from tensorrt_llm._utils import nvtx_range
|
||||
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
"""
|
||||
Perform ceiling division of two integers.
|
||||
|
||||
Args:
|
||||
x: the dividend.
|
||||
y: the divisor.
|
||||
|
||||
Returns:
|
||||
The result of the ceiling division.
|
||||
"""
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def align(x: int, y: int) -> int:
|
||||
return ceil_div(x, y) * y
|
||||
|
||||
|
||||
def ceil_to_ue8m0(x: torch.Tensor):
|
||||
return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
|
||||
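A quick numeric illustration of the UE8M0 rounding above: magnitudes are rounded up to the next power of two, which is all an exponent-only (E8M0) scale can represent. Values are arbitrary:

import torch

x = torch.tensor([0.3, 1.0, 1.5, 6.0])
print(torch.pow(2.0, torch.ceil(torch.log2(x.abs()))))
# tensor([0.5000, 1.0000, 2.0000, 8.0000])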
|
||||
|
||||
@nvtx_range("[DG] quantization")
|
||||
@torch.compile(dynamic=True)
|
||||
def per_token_cast_to_fp8_e8m0(
|
||||
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if x.dim() == 2:
|
||||
assert x.size(1) % 128 == 0
|
||||
m, n = x.shape
|
||||
x_view = x.view(m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
sf = ceil_to_ue8m0(x_amax / 448.0)
|
||||
return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(
|
||||
m, n), sf
|
||||
else:
|
||||
assert x.size(2) % 128 == 0
|
||||
g, m, n = x.shape
|
||||
x_view = x.view(g, m, -1, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=3).view(g, m, -1).clamp(1e-4)
|
||||
sf = ceil_to_ue8m0(x_amax / 448.0)
|
||||
return (x_view * (1.0 / sf.unsqueeze(3))).to(torch.float8_e4m3fn).view(
|
||||
g, m, n), sf
|
||||
|
||||
|
||||
def per_block_cast_to_fp8_e8m0(
|
||||
x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if x.dim() == 2:
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros((align(m, 128), align(n, 128)),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
sf = ceil_to_ue8m0(x_amax / 448.0)
|
||||
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
|
||||
x_view.size(0), x_view.size(2))
|
||||
else:
|
||||
g, m, n = x.shape
|
||||
x_padded = torch.zeros((g, align(m, 128), align(n, 128)),
|
||||
dtype=x.dtype,
|
||||
device=x.device)
|
||||
x_padded[:, :m, :n] = x
|
||||
x_view = x_padded.view(g, -1, 128, x_padded.size(-1) // 128, 128)
|
||||
x_amax = x_view.abs().float().amax(dim=(2, 4), keepdim=True).clamp(1e-4)
|
||||
sf = ceil_to_ue8m0(x_amax / 448.0)
|
||||
x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
|
||||
return x_scaled.view_as(x_padded)[:, :m, :n].contiguous(), sf.view(
|
||||
x_view.size(0), x_view.size(1), x_view.size(3))
|
||||
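A shape-only sketch of the block-wise path above (arbitrary sizes, CUDA device assumed): each 128x128 block gets a single power-of-two scale.

import torch

w = torch.randn(384, 512, device='cuda')
w_fp8, sf = per_block_cast_to_fp8_e8m0(w)
print(w_fp8.shape, w_fp8.dtype)  # torch.Size([384, 512]) torch.float8_e4m3fn
print(sf.shape)                  # torch.Size([3, 4]) -> one scale per 128x128 block
assert torch.all(sf == torch.pow(2.0, torch.round(torch.log2(sf))))  # E8M0: power-of-two scales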
|
||||
|
||||
def resmooth_to_fp8_e8m0(weight: torch.Tensor,
|
||||
sf: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
weight = weight.cuda()
|
||||
sf = sf.cuda()
|
||||
if weight.dim() == 2:
|
||||
x = weight.float() * sf.repeat_interleave(128, dim=0).repeat_interleave(
|
||||
128, dim=1)[:weight.shape[0], :weight.shape[1]]
|
||||
else:
|
||||
x = weight.float() * sf.repeat_interleave(128, dim=1).repeat_interleave(
|
||||
128, dim=2)[:weight.shape[0], :weight.shape[1], :weight.shape[2]]
|
||||
return per_block_cast_to_fp8_e8m0(x)
|
||||
|
||||
|
||||
def get_m_alignment_for_contiguous_layout():
|
||||
return 128
|
||||
|
||||
|
||||
def get_tma_aligned_size(x: int, element_size: int) -> int:
|
||||
tma_alignment_bytes = 16
|
||||
assert tma_alignment_bytes % element_size == 0
|
||||
alignment = tma_alignment_bytes // element_size
|
||||
return align(x, alignment)
|
||||
|
||||
|
||||
def get_col_major_tma_aligned_packed_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: for extreme performance, you may rewrite/fuse this function in CUDA
|
||||
assert x.dtype == torch.float and x.dim() in (2, 3)
|
||||
|
||||
# First, convert into UE8M0 `uint8_t`
|
||||
ue8m0_tensor = (x.view(torch.int) >> 23).to(torch.uint8)
|
||||
|
||||
# Second, make padded packed tensors
|
||||
mn, k = x.shape[-2], x.shape[-1]
|
||||
remove_dim = False
|
||||
if x.dim() == 2:
|
||||
x, remove_dim = x.unsqueeze(0), True
|
||||
b = x.shape[0]
|
||||
aligned_mn = get_tma_aligned_size(mn, 4)
|
||||
aligned_k = align(k, 4)
|
||||
padded = torch.zeros((b, aligned_mn, aligned_k),
|
||||
device=x.device,
|
||||
dtype=torch.uint8)
|
||||
padded[:, :mn, :k] = ue8m0_tensor
|
||||
padded = padded.view(-1).view(dtype=torch.int).view(b, aligned_mn,
|
||||
aligned_k // 4)
|
||||
|
||||
# Finally, transpose
|
||||
transposed = torch.transpose(
|
||||
torch.empty((b, aligned_k // 4, aligned_mn),
|
||||
device=x.device,
|
||||
dtype=torch.int), 1, 2)
|
||||
transposed[:, :, :] = padded
|
||||
aligned_x = transposed[:, :mn, :]
|
||||
return aligned_x.squeeze(0) if remove_dim else aligned_x
|
||||
|
||||
|
||||
def check_sf_layout(sf: torch.Tensor,
|
||||
mn: int,
|
||||
k: int,
|
||||
gran: Tuple[int, int],
|
||||
num_groups: Optional[int],
|
||||
tma_stride_check: bool = False,
|
||||
type_check: Optional[torch.dtype] = None) -> torch.Tensor:
|
||||
# Type check
|
||||
if type_check is not None:
|
||||
assert sf.dtype == type_check
|
||||
|
||||
# Always do shape checks
|
||||
assert sf.dtype in (torch.float, torch.int)
|
||||
assert sf.dim() == int(num_groups is not None) + 2
|
||||
if num_groups is not None:
|
||||
assert sf.size(-3) == num_groups
|
||||
assert sf.size(-2) == ceil_div(mn, gran[0])
|
||||
assert sf.size(-1) == ceil_div(
|
||||
k, gran[1] * (1 if sf.dtype == torch.float else 4))
|
||||
|
||||
# TMA stride checks: TMA aligned and MN-major
|
||||
if tma_stride_check:
|
||||
if num_groups is not None:
|
||||
assert sf.stride(-3) == sf.stride(-1) * sf.size(-1)
|
||||
assert sf.stride(-2) == 1
|
||||
assert sf.stride(-1) == get_tma_aligned_size(mn, sf.element_size())
|
||||
|
||||
return sf
|
||||
|
||||
|
||||
@nvtx_range("[DG] transform_sf_into_required_layout")
|
||||
def transform_sf_into_required_layout(sf: torch.Tensor,
|
||||
mn: int,
|
||||
k: int,
|
||||
recipe: Tuple[int, int, int],
|
||||
num_groups: Optional[int] = None,
|
||||
is_sfa: bool = False):
|
||||
gran = (recipe[0 if is_sfa else 1], recipe[2])
|
||||
|
||||
should_skip_transform = ((sf.dtype == torch.int and gran == (1, 128))
|
||||
or (sf.dtype == torch.int and gran == (128, 128)))
|
||||
|
||||
if not should_skip_transform:
|
||||
# Pre-transform checks
|
||||
check_sf_layout(sf, mn=mn, k=k, gran=gran, num_groups=num_groups)
|
||||
|
||||
# (FP32, 1, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
|
||||
if sf.dtype == torch.float and gran == (1, 128):
|
||||
sf = get_col_major_tma_aligned_packed_tensor(sf)
|
||||
return check_sf_layout(sf,
|
||||
mn=mn,
|
||||
k=k,
|
||||
gran=(1, 128),
|
||||
num_groups=num_groups,
|
||||
tma_stride_check=True,
|
||||
type_check=torch.int)
|
||||
|
||||
# (FP32, 128, 128) on SM100: transform to (INT, 1, 128), TMA-aligned and MN-major
|
||||
if sf.dtype == torch.float and gran == (128, 128):
|
||||
sf = sf.index_select(-2, torch.arange(mn, device=sf.device) // 128)
|
||||
sf = get_col_major_tma_aligned_packed_tensor(sf)
|
||||
return check_sf_layout(sf,
|
||||
mn=mn,
|
||||
k=k,
|
||||
gran=(1, 128),
|
||||
num_groups=num_groups,
|
||||
tma_stride_check=True,
|
||||
type_check=torch.int)
|
||||
|
||||
if should_skip_transform:
|
||||
# TODO: add transpose kernel if SF layout is not satisfied
|
||||
return check_sf_layout(sf,
|
||||
mn=mn,
|
||||
k=k,
|
||||
gran=(1, 128),
|
||||
num_groups=num_groups,
|
||||
tma_stride_check=True,
|
||||
type_check=torch.int)
|
||||
|
||||
assert False, f'Unknown cases: {sf.dtype=}, {gran=}'
|
||||
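A hedged end-to-end sketch of the transform above for the (128, 128) weight-scale recipe (arbitrary sizes, CUDA device assumed); it packs FP32 block scales into the TMA-aligned, MN-major UE8M0 layout that the SM100 kernels consume:

import torch

w = torch.randn(384, 512, device='cuda')
_, sf = per_block_cast_to_fp8_e8m0(w)               # FP32 scales, shape [3, 4]
packed = transform_sf_into_required_layout(sf,
                                            mn=384,
                                            k=512,
                                            recipe=(1, 128, 128),
                                            is_sfa=False)
print(packed.dtype, packed.shape)                   # torch.int32 torch.Size([384, 1])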
|
||||
|
||||
# copy from https://github.com/ModelTC/lightllm/blob/a000ab69098654df4731f5b12587dd4e7f0a4f41/lightllm/common/fused_moe/moe_silu_and_mul_mix_quant_ep.py
|
||||
@triton.jit
|
||||
def _silu_and_mul_post_quant_kernel(
|
||||
input_ptr,
|
||||
stride_input_0,
|
||||
stride_input_1,
|
||||
stride_input_2,
|
||||
output_ptr,
|
||||
stride_output_0,
|
||||
stride_output_1,
|
||||
stride_output_2,
|
||||
output_scale_ptr,
|
||||
stride_output_scale_0,
|
||||
stride_output_scale_1,
|
||||
stride_output_scale_2,
|
||||
masked_m_ptr,
|
||||
size_k,
|
||||
fp8_max,
|
||||
fp8_min,
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_STAGE: tl.constexpr,
|
||||
SCALE_UE8M0: tl.constexpr,
|
||||
):
|
||||
expert_id = tl.program_id(2)
|
||||
token_id = tl.program_id(1)
|
||||
hidden_dim_block_index = tl.program_id(0)
|
||||
|
||||
block_num_per_expert = tl.num_programs(1)
|
||||
|
||||
token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
|
||||
|
||||
stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
|
||||
stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
|
||||
stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
|
||||
stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
|
||||
|
||||
offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4)
|
||||
input_ptr_offs = input_ptr + expert_id * stride_input_0 + offs_in_d
|
||||
output_ptr_offs = output_ptr + expert_id * stride_output_0 + offs_in_d
|
||||
output_scale_offs = (output_scale_ptr + expert_id * stride_output_scale_0 +
|
||||
hidden_dim_block_index * stride_output_scale_1)
|
||||
|
||||
for token_index in tl.range(token_id,
|
||||
token_num_cur_expert,
|
||||
block_num_per_expert,
|
||||
num_stages=NUM_STAGE):
|
||||
output_s_int32 = 0
|
||||
for pack_index in tl.range(4):
|
||||
local_mask = offs_in_d + pack_index * 128
|
||||
up = tl.load(
|
||||
input_ptr_offs + token_index * stride_input_1 +
|
||||
pack_index * 128,
|
||||
mask=local_mask < size_k,
|
||||
other=0.0,
|
||||
)
|
||||
gate = tl.load(
|
||||
input_ptr_offs + token_index * stride_input_1 + size_k +
|
||||
pack_index * 128,
|
||||
mask=local_mask < size_k,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
gate = gate / (1 + tl.exp(-gate))
|
||||
gate = gate.to(input_ptr.dtype.element_ty)
|
||||
gate_up = up * gate
|
||||
_absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
|
||||
output_s = _absmax / fp8_max
|
||||
if SCALE_UE8M0:
|
||||
output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
|
||||
output_q = tl.clamp(gate_up / output_s, fp8_min,
|
||||
fp8_max).to(output_ptr.dtype.element_ty)
|
||||
output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) <<
|
||||
(8 * pack_index))
|
||||
tl.store(
|
||||
output_ptr_offs + token_index * stride_output_1 +
|
||||
pack_index * 128,
|
||||
output_q,
|
||||
mask=local_mask < size_k,
|
||||
)
|
||||
tl.store(
|
||||
output_scale_offs + token_index * stride_output_scale_2,
|
||||
output_s_int32,
|
||||
)
|
||||
|
||||
|
||||
def silu_and_mul_masked_post_quant_fwd(
|
||||
input: torch.Tensor,
|
||||
quant_group_size: int,
|
||||
masked_m: torch.Tensor,
|
||||
scale_ue8m0: bool = False,
|
||||
):
|
||||
"""
|
||||
input shape [g, m, k]
|
||||
output shape [g, m, k // 2], dtype fp8
|
||||
output_scale [g, k // 4, m // 2 // 128], dtype int32
|
||||
quant_group_size int
|
||||
masked_m shape [g]
|
||||
"""
|
||||
|
||||
assert input.is_contiguous()
|
||||
assert len(input.shape) == 3
|
||||
assert input.shape[0] == masked_m.shape[0]
|
||||
assert input.shape[-1] % 2 == 0
|
||||
|
||||
# FP8 quantization parameters
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max = finfo.max
|
||||
fp8_min = finfo.min
|
||||
|
||||
g, m, k = input.shape
|
||||
k = k // 2
|
||||
|
||||
# Create output
|
||||
output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda")
|
||||
|
||||
# Create output scale
|
||||
alignment = 4
|
||||
scale_k = ceil_div(k, quant_group_size)
|
||||
m_padded = align(m, alignment)
|
||||
scale_k_padded = align(scale_k, alignment)
|
||||
output_scale = torch.zeros((g, scale_k_padded // 4, m_padded),
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
|
||||
# Get block/grid/stage/warp
|
||||
expert_num = len(masked_m)
|
||||
|
||||
if expert_num < 4:
|
||||
BLOCK_NUM_PER_EXPERT = 64
|
||||
else:
|
||||
BLOCK_NUM_PER_EXPERT = 128
|
||||
|
||||
BLOCK = quant_group_size * 4
|
||||
num_warps = 1
|
||||
NUM_STAGES = 6
|
||||
hidden_dim_split_block_num = triton.cdiv(k, BLOCK)
|
||||
grid = (
|
||||
hidden_dim_split_block_num,
|
||||
BLOCK_NUM_PER_EXPERT,
|
||||
expert_num,
|
||||
)
|
||||
_silu_and_mul_post_quant_kernel[grid](
|
||||
input,
|
||||
*input.stride(),
|
||||
output,
|
||||
*output.stride(),
|
||||
output_scale,
|
||||
*output_scale.stride(),
|
||||
masked_m,
|
||||
k,
|
||||
fp8_max,
|
||||
fp8_min,
|
||||
BLOCK=BLOCK,
|
||||
NUM_STAGE=NUM_STAGES,
|
||||
num_warps=num_warps,
|
||||
SCALE_UE8M0=scale_ue8m0,
|
||||
)
|
||||
output_scale = output_scale.transpose(1, 2)[:, :m, :]
|
||||
check_sf_layout(
|
||||
output_scale,
|
||||
m,
|
||||
k,
|
||||
(1, 128),
|
||||
g,
|
||||
tma_stride_check=True,
|
||||
)
|
||||
return output, output_scale
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _per_token_quant_and_transform_kernel(
|
||||
input_ptr,
|
||||
stride_input_0,
|
||||
stride_input_1,
|
||||
output_ptr,
|
||||
stride_output_0,
|
||||
stride_output_1,
|
||||
output_scale_ptr,
|
||||
stride_output_scale_0,
|
||||
stride_output_scale_1,
|
||||
token_num_cur_expert,
|
||||
size_k,
|
||||
fp8_max,
|
||||
fp8_min,
|
||||
BLOCK: tl.constexpr,
|
||||
NUM_STAGE: tl.constexpr,
|
||||
SCALE_UE8M0: tl.constexpr,
|
||||
):
|
||||
tl.program_id(2)
|
||||
token_id = tl.program_id(1)
|
||||
hidden_dim_block_index = tl.program_id(0)
|
||||
|
||||
block_num_per_expert = tl.num_programs(1)
|
||||
|
||||
stride_input_0 = tl.cast(stride_input_0, dtype=tl.int64)
|
||||
stride_output_0 = tl.cast(stride_output_0, dtype=tl.int64)
|
||||
stride_input_1 = tl.cast(stride_input_1, dtype=tl.int64)
|
||||
stride_output_1 = tl.cast(stride_output_1, dtype=tl.int64)
|
||||
|
||||
offs_in_d = hidden_dim_block_index * BLOCK + tl.arange(0, BLOCK // 4)
|
||||
input_ptr_offs = input_ptr + offs_in_d
|
||||
output_ptr_offs = output_ptr + offs_in_d
|
||||
output_scale_offs = (output_scale_ptr +
|
||||
hidden_dim_block_index * stride_output_scale_0)
|
||||
|
||||
for token_index in tl.range(token_id,
|
||||
token_num_cur_expert,
|
||||
block_num_per_expert,
|
||||
num_stages=NUM_STAGE):
|
||||
output_s_int32 = 0
|
||||
for pack_index in tl.range(4):
|
||||
local_mask = offs_in_d + pack_index * 128
|
||||
act = tl.load(
|
||||
input_ptr_offs + token_index * stride_input_0 +
|
||||
pack_index * 128,
|
||||
mask=local_mask < size_k,
|
||||
other=0.0,
|
||||
).to(tl.float32)
|
||||
_absmax = tl.maximum(tl.max(tl.abs(act)), 1e-10)
|
||||
output_s = _absmax / fp8_max
|
||||
if SCALE_UE8M0:
|
||||
output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
|
||||
output_q = tl.clamp(act / output_s, fp8_min,
|
||||
fp8_max).to(output_ptr.dtype.element_ty)
|
||||
output_s_int32 += ((output_s.to(tl.int32, bitcast=True) >> 23) <<
|
||||
(8 * pack_index))
|
||||
tl.store(
|
||||
output_ptr_offs + token_index * stride_output_0 +
|
||||
pack_index * 128,
|
||||
output_q,
|
||||
mask=local_mask < size_k,
|
||||
)
|
||||
tl.store(
|
||||
output_scale_offs + token_index * stride_output_scale_1,
|
||||
output_s_int32,
|
||||
)
|
||||
|
||||
|
||||
def per_token_quant_and_transform(
|
||||
input: torch.Tensor,
|
||||
quant_group_size: int = 128,
|
||||
scale_ue8m0: bool = True,
|
||||
):
|
||||
"""
|
||||
input shape [g, m, k]
|
||||
output shape [g, m, k // 2], dtype fp8
|
||||
output_scale [g, k // 4, m // 2 // 128], dtype int32
|
||||
quant_group_size int
|
||||
masked_m shape [g]
|
||||
"""
|
||||
|
||||
assert input.is_contiguous()
|
||||
assert len(input.shape) == 2
|
||||
assert input.shape[-1] % 2 == 0
|
||||
|
||||
# FP8 quantization parameters
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max = finfo.max
|
||||
fp8_min = -fp8_max
|
||||
|
||||
m, k = input.shape
|
||||
|
||||
# Create output
|
||||
output = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda")
|
||||
|
||||
# Create output scale
|
||||
alignment = 4
|
||||
scale_k = ceil_div(k, quant_group_size)
|
||||
m_padded = align(m, alignment)
|
||||
scale_k_padded = align(scale_k, alignment)
|
||||
output_scale = torch.zeros((scale_k_padded // 4, m_padded),
|
||||
dtype=torch.int32,
|
||||
device='cuda')
|
||||
|
||||
# Get block/grid/stage/warp
|
||||
BLOCK_NUM_PER_EXPERT = 64
|
||||
|
||||
BLOCK = quant_group_size * 4
|
||||
num_warps = 1
|
||||
NUM_STAGES = 6
|
||||
hidden_dim_split_block_num = triton.cdiv(k, BLOCK)
|
||||
grid = (
|
||||
hidden_dim_split_block_num,
|
||||
BLOCK_NUM_PER_EXPERT,
|
||||
1,
|
||||
)
|
||||
_per_token_quant_and_transform_kernel[grid](
|
||||
input,
|
||||
*input.stride(),
|
||||
output,
|
||||
*output.stride(),
|
||||
output_scale,
|
||||
*output_scale.stride(),
|
||||
m,
|
||||
k,
|
||||
fp8_max,
|
||||
fp8_min,
|
||||
BLOCK=BLOCK,
|
||||
NUM_STAGE=NUM_STAGES,
|
||||
num_warps=num_warps,
|
||||
SCALE_UE8M0=scale_ue8m0,
|
||||
)
|
||||
output_scale = output_scale.transpose(0, 1)[:m, :]
|
||||
check_sf_layout(
|
||||
output_scale,
|
||||
m,
|
||||
k,
|
||||
(1, 128),
|
||||
num_groups=None,
|
||||
tma_stride_check=True,
|
||||
)
|
||||
return output, output_scale
|
||||
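A shape sketch for the per-token activation path above (arbitrary sizes, CUDA device assumed); this is the quantization used by the SM100 branch of FP8BlockScalesLinearMethod later in this diff:

import torch

x = torch.randn(64, 1024, dtype=torch.bfloat16, device='cuda')
x_fp8, x_sf = per_token_quant_and_transform(x)
print(x_fp8.shape, x_fp8.dtype)  # torch.Size([64, 1024]) torch.float8_e4m3fn
print(x_sf.shape, x_sf.dtype)    # torch.Size([64, 2]) torch.int32 (packed UE8M0, 1x128 groups)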
@ -8,6 +8,14 @@ def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y


def align(x: int, y: int) -> int:
    return ceil_div(x, y) * y


def ceil_to_ue8m0(x: torch.Tensor):
    return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))


def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
@ -33,6 +41,33 @@ def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                             x_view.size(2))


def per_token_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2 and x.size(1) % 128 == 0
    m, n = x.shape
    x_view = x.view(m, -1, 128)
    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
    sf = ceil_to_ue8m0(x_amax / 448.0)
    return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(
        m, n), sf


def per_block_cast_to_fp8_e8m0(
        x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    assert x.dim() == 2
    m, n = x.shape
    x_padded = torch.zeros((align(m, 128), align(n, 128)),
                           dtype=x.dtype,
                           device=x.device)
    x_padded[:m, :n] = x
    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    sf = ceil_to_ue8m0(x_amax / 448.0)
    x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn)
    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(
        x_view.size(0), x_view.size(2))


def calc_diff(x, y):
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
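As a quick illustration of what the new `*_e8m0` helpers compute, the following standalone snippet (not from the commit) re-derives the per-token path inline: the per-128-column amax is divided by the FP8 e4m3 maximum (448) and rounded up to a power of two, so every scale is an exact UE8M0 exponent and dequantization stays cheap. It only needs a PyTorch build with float8 dtypes.

# Standalone sketch of the UE8M0 per-token scaling used by the helpers above.
import torch

torch.manual_seed(0)
x = torch.randn(4, 256, dtype=torch.float32)  # 2 groups of 128 per row

x_view = x.view(4, -1, 128)
amax = x_view.abs().amax(dim=2, keepdim=True).clamp(1e-4)
sf = torch.pow(2.0, torch.ceil(torch.log2(amax / 448.0)))  # ceil_to_ue8m0
x_fp8 = (x_view / sf).to(torch.float8_e4m3fn)

# Every scale is an exact power of two, and the round trip stays accurate.
assert torch.allclose(torch.log2(sf), torch.log2(sf).round())
x_dq = x_fp8.to(torch.float32) * sf
rel_err = (x_dq.view(4, 256) - x).abs().max() / x.abs().max()
assert rel_err < 0.1  # e4m3 keeps roughly two significant digits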
@ -9,7 +9,8 @@ import cloudpickle
import pytest
import torch
import torch.nn as nn
from _torch.helpers import per_block_cast_to_fp8
from _torch.helpers import (per_block_cast_to_fp8, per_block_cast_to_fp8_e8m0,
                            per_token_cast_to_fp8_e8m0)
from mpi4py import MPI
from mpi4py.futures import MPIPoolExecutor
from utils.util import (skip_neither_ada_nor_hopper_unittest,
@ -25,6 +26,8 @@ from tensorrt_llm._torch.modules.fused_moe import (BaseMoeRoutingMethod,
                                                    VanillaMoE, WideEPMoE)
from tensorrt_llm._torch.modules.fused_moe.fused_moe_cute_dsl import \
    CuteDslFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_deepgemm import \
    DeepGemmFusedMoE
from tensorrt_llm._torch.modules.fused_moe.fused_moe_wide_ep import \
    AlltoallMethodType
from tensorrt_llm._torch.modules.gated_mlp import GatedMLP
@ -379,6 +382,183 @@ def set_tensor_value_4(x, num_row, num_cols):
    x.copy_(repeated)


@skip_pre_blackwell
@pytest.mark.parametrize(
    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls",
    product(
        [torch.bfloat16],
        [72],
        [128, 256, 384, 512, 1024, 2048, 4096, 8192],
        [2560],
        [DefaultMoeRoutingMethod],
    ),
)
def test_fused_moe_fp8_blockwise_deepgemm(dtype,
                                          num_experts,
                                          seq_len,
                                          hidden_size,
                                          RoutingMethodCls,
                                          mapping=None):
    SEQ_LEN = seq_len
    HIDDEN_SIZE = hidden_size
    INTERMEDIATE_SIZE = 256
    NUM_EXPERTS = num_experts
    TOP_K = 2

    routing_method = RoutingMethodCls(top_k=TOP_K)

    mapping = mapping or Mapping()
    mapping.rank = mpi_rank()
    torch.cuda.set_device(mapping.rank)
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    x = torch.randn((SEQ_LEN, HIDDEN_SIZE), dtype=dtype).cuda()
    # Note: initialize x and the weights with special values; otherwise the
    # test can pass as a false positive.
    set_tensor_value_2(x, SEQ_LEN, HIDDEN_SIZE)

    x = x.cuda()
    router_logits = torch.randn((SEQ_LEN, NUM_EXPERTS), dtype=dtype).cuda()

    weights = {}
    w3_w1_weight_scales = []
    w2_weight_scales = []
    for expert_id in range(NUM_EXPERTS):
        w1_weight = torch.randn(
            (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() / HIDDEN_SIZE
        w2_weight = torch.randn((HIDDEN_SIZE, INTERMEDIATE_SIZE),
                                dtype=dtype).cuda()
        w3_weight = torch.randn(
            (INTERMEDIATE_SIZE, HIDDEN_SIZE), dtype=dtype).cuda() / HIDDEN_SIZE
        set_tensor_value_3(w1_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE)
        set_tensor_value_4(w2_weight, HIDDEN_SIZE, INTERMEDIATE_SIZE)
        set_tensor_value_3(w3_weight, INTERMEDIATE_SIZE, HIDDEN_SIZE)

        w1_weight_fp8, w1_weight_scale = per_block_cast_to_fp8_e8m0(w1_weight)
        w1_weight_fp8 = w1_weight_fp8.view(torch.float8_e4m3fn).cuda()

        w2_weight_fp8, w2_weight_scale = per_block_cast_to_fp8_e8m0(w2_weight)
        w2_weight_fp8 = w2_weight_fp8.view(torch.float8_e4m3fn).cuda()

        w3_weight_fp8, w3_weight_scale = per_block_cast_to_fp8_e8m0(w3_weight)
        w3_weight_fp8 = w3_weight_fp8.view(torch.float8_e4m3fn).cuda()

        weights[f"{expert_id}.w1.weight"] = w1_weight_fp8
        weights[f"{expert_id}.w2.weight"] = w2_weight_fp8
        weights[f"{expert_id}.w3.weight"] = w3_weight_fp8
        weights[f"{expert_id}.w1.weight_scale_inv"] = w1_weight_scale
        weights[f"{expert_id}.w2.weight_scale_inv"] = w2_weight_scale
        weights[f"{expert_id}.w3.weight_scale_inv"] = w3_weight_scale
        weights[f"{expert_id}.w1.weight_scale"] = w1_weight_scale
        weights[f"{expert_id}.w2.weight_scale"] = w2_weight_scale
        weights[f"{expert_id}.w3.weight_scale"] = w3_weight_scale

        w3_w1_weight_scales.append(
            torch.cat([w3_weight_scale, w1_weight_scale], dim=0))
        w2_weight_scales.append(w2_weight_scale)

    w3_w1_weight_scales = torch.stack(w3_w1_weight_scales, dim=0).cuda()
    w2_weight_scales = torch.stack(w2_weight_scales, dim=0).cuda()

    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8_BLOCK_SCALES)

    fused_moe = DeepGemmFusedMoE(
        num_experts=NUM_EXPERTS,
        routing_method=routing_method,
        hidden_size=HIDDEN_SIZE,
        intermediate_size=INTERMEDIATE_SIZE,
        dtype=dtype,
        reduce_results=True,
        model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
    )
    fused_moe.cuda()
    fused_moe.load_weights([weights])

    def swiglu_fused_moe(x):
        x, gate = x.chunk(2, dim=-1)
        return torch.nn.functional.silu(gate) * x

    def grouped_gemm(a: torch.Tensor, b: torch.Tensor, a_sf: torch.Tensor,
                     b_sf: torch.Tensor,
                     offset_array: torch.Tensor) -> torch.Tensor:
        num_groups, n, k_ = b.shape
        d = torch.empty((a.shape[0], b.shape[1]),
                        device=b.device,
                        dtype=torch.bfloat16)
        m_indices = torch.empty(a.shape[0], device=b.device, dtype=torch.int32)
        for idx in range(offset_array.numel() - 1):
            m_indices[offset_array[idx]:offset_array[idx + 1]] = idx

        for g in range(num_groups):
            aa = a[offset_array[g]:offset_array[g + 1], :].to(torch.bfloat16)
            aa_sf = a_sf[offset_array[g]:offset_array[g + 1], :]
            aa_dq = aa * aa_sf.repeat_interleave(
                128, dim=1)[:aa.shape[0], :aa.shape[1]]
            bb = b[g, :, :].to(torch.bfloat16)
            bb_sf = b_sf[g, :, :]
            bb_dq = bb * bb_sf.repeat_interleave(128, dim=0).repeat_interleave(
                128, dim=1)[:bb.shape[0], :bb.shape[1]]
            d[offset_array[g]:offset_array[g + 1], :] = (aa_dq @ bb_dq.t())
        return d

    token_selected_experts, token_final_scales = routing_method.apply(
        router_logits)
    t_idx = 0
    permuted_data_tensor = torch.empty((x.shape[0] * TOP_K, x.shape[1]),
                                       device=x.device,
                                       dtype=torch.bfloat16)
    expert_first_token_offset_tensor = torch.zeros(NUM_EXPERTS + 1,
                                                   dtype=torch.int32)
    unpermute_map = []
    scales = []
    for e_idx in range(NUM_EXPERTS):
        for idx, token in enumerate(x):
            for i, selected_expert in enumerate(token_selected_experts[idx]):
                if e_idx == selected_expert:
                    permuted_data_tensor[t_idx, :] = token
                    unpermute_map.append(idx)
                    scales.append(token_final_scales[idx, i])
                    t_idx += 1
        expert_first_token_offset_tensor[e_idx + 1] = t_idx

    act_input_fp8, act_input_sf = per_token_cast_to_fp8_e8m0(
        permuted_data_tensor)
    h1 = grouped_gemm(
        a=act_input_fp8,
        b=fused_moe.w3_w1_weight,
        a_sf=act_input_sf,
        b_sf=w3_w1_weight_scales,
        offset_array=expert_first_token_offset_tensor,
    )
    h2 = swiglu_fused_moe(h1)
    act_input_fp8, act_input_sf = per_token_cast_to_fp8_e8m0(h2)
    h3 = grouped_gemm(
        a=act_input_fp8,
        b=fused_moe.w2_weight,
        a_sf=act_input_sf,
        b_sf=w2_weight_scales,
        offset_array=expert_first_token_offset_tensor,
    )
    ref_output = torch.zeros_like(x)
    for token_idx, h3_token in enumerate(h3):
        original_idx = unpermute_map[token_idx]
        ref_output[original_idx, :] += h3_token * scales[token_idx]

    with torch.inference_mode():
        output = fused_moe.forward(x, router_logits)

    # compare
    torch.cuda.synchronize()
    torch.testing.assert_close(output, ref_output, rtol=1e-2, atol=0.1)


@skip_non_hopper_unittest
@pytest.mark.parametrize(
    "dtype, num_experts, seq_len, hidden_size, RoutingMethodCls",
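The reference `grouped_gemm` in the test above dequantizes the 128x128 block scales by expanding them with `repeat_interleave`. The small self-contained sketch below (not from the commit; all names are illustrative) shows just that expansion, with an explicit blockwise loop to make the broadcast semantics concrete.

# Sketch of the 128x128 block-scale expansion used by the reference check.
import torch

n, k = 256, 384                                  # 2 x 3 blocks of 128
w = torch.randn(n, k)
sf = torch.rand(n // 128, k // 128) + 0.5        # one scale per 128x128 block

# Expand each block scale to cover its 128x128 tile, then apply it.
sf_full = sf.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)
w_scaled = w * sf_full

# Equivalent blockwise loop, to show what the broadcast means.
ref = torch.empty_like(w)
for i in range(n // 128):
    for j in range(k // 128):
        ref[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] = \
            w[i * 128:(i + 1) * 128, j * 128:(j + 1) * 128] * sf[i, j]
assert torch.allclose(w_scaled, ref)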
@ -18,10 +18,48 @@ import sys

import pytest
import torch
from _torch.helpers import calc_diff, per_block_cast_to_fp8
from _torch.helpers import (calc_diff, per_block_cast_to_fp8,
                            per_block_cast_to_fp8_e8m0,
                            per_token_cast_to_fp8_e8m0)
from utils.util import getSMVersion


@pytest.mark.skipif(
    getSMVersion() != 100,
    reason="The test is for Blackwell only. Current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize(
    "k, n",
    [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096),
     (2048, 7168), (1024, 1024)],
)
@pytest.mark.parametrize(
    "m",
    [7, 64, 128, 4096],
)
@pytest.mark.parametrize(
    "dtype",
    [torch.bfloat16],
)
def test_fp8_block_scale_deep_gemm(dtype, m, k, n):
    torch.random.manual_seed(0)
    a = torch.randn((m, k), device='cuda', dtype=dtype)
    b = torch.randn((n, k), device='cuda', dtype=dtype)

    act_a_fp8, act_a_sf = per_token_cast_to_fp8_e8m0(a)
    act_b_fp8, act_b_sf = per_block_cast_to_fp8_e8m0(b)

    output_expected = a @ b.t()
    import deep_gemm
    output = torch.empty((act_a_fp8.shape[0], act_b_fp8.shape[0]),
                         device=act_a_fp8.device,
                         dtype=torch.bfloat16)

    deep_gemm.fp8_gemm_nt((act_a_fp8, act_a_sf), (act_b_fp8, act_b_sf), output)
    diff = calc_diff(output, output_expected)
    assert diff < 1e-2


@pytest.mark.skipif(
    getSMVersion() != 100 and getSMVersion() != 89,
    reason="The test is for Blackwell and Ada only. Current SM is %d." %
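For context on what `test_fp8_block_scale_deep_gemm` tolerates, the sketch below (not from the commit; pure PyTorch, no deep_gemm call, `pow2_scale` is an illustrative helper) builds the same per-token / per-block UE8M0 quantization, dequantizes it, and checks the dequantized matmul against the unquantized baseline using the same calc_diff-style metric and 1e-2 tolerance; the deep_gemm kernel is expected to track the dequantized product up to accumulation order.

# Pure-PyTorch sketch of the quantization error budget the test allows.
import torch

torch.manual_seed(0)
m, k, n = 64, 1024, 512
a = torch.randn(m, k)
b = torch.randn(n, k)


def pow2_scale(amax):
    # UE8M0 scale: amax / 448 rounded up to the next power of two.
    return torch.exp2(torch.ceil(torch.log2(amax.clamp(1e-4) / 448.0)))


# Per-token (1 x 128) scales for activations, per-block (128 x 128) for weights.
a_sf = pow2_scale(a.view(m, -1, 128).abs().amax(dim=2, keepdim=True))
a_fp8 = (a.view(m, -1, 128) / a_sf).to(torch.float8_e4m3fn)
b_view = b.view(n // 128, 128, k // 128, 128)
b_sf = pow2_scale(b_view.abs().amax(dim=(1, 3), keepdim=True))
b_fp8 = (b_view / b_sf).to(torch.float8_e4m3fn)

# Dequantize and compare with the same metric and tolerance as calc_diff.
a_dq = (a_fp8.to(torch.float32) * a_sf).view(m, k)
b_dq = (b_fp8.to(torch.float32) * b_sf).view(n, k)
x, y = a_dq @ b_dq.t(), a @ b.t()
diff = 1 - 2 * (x * y).sum() / (x * x + y * y).sum()
assert diff < 1e-2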
@ -51,6 +51,9 @@ def test_pip_install():
                        help="The wheel path")
    args = parser.parse_args()

    if not os.environ.get("CUDA_HOME"):
        os.environ["CUDA_HOME"] = "/usr/local/cuda"

    print("########## Install required system libs ##########")
    if not os.path.exists("/usr/local/mpi/bin/mpicc"):
        subprocess.check_call("apt-get -y install libopenmpi-dev", shell=True)