TensorRT-LLMs/tensorrt_llm/_torch/modules/linear.py
Yukun He 9c5b464fe0
[None][feat] Apply AutoTuner to fp8_block_scale_deep_gemm to trigger JIT ahead of time. (#7113)
Because deep_gemm.fp8_gemm_nt triggers many JIT compilation processes during the inference phase, we need to sweep these shapes ahead of time. Apply the AutoTuner framework to achieve this, and retain the potential capability to tune the swap_ab flag.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
2025-08-25 10:48:31 +08:00


from __future__ import annotations
import enum
import math
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.peft.lora.layer import LoraLayer
from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams,
AllReduceStrategy)
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.quantization.functional import \
preprocess_weights_for_mixed_gemm
from tensorrt_llm.quantization.mode import QuantAlgo
from ..._utils import get_sm_version
from ...models.modeling_utils import QuantConfig
from ..utils import Fp4QuantizedTensor
class WeightMode(str, enum.Enum):
# weight of a vanilla layer
VANILLA = 'vanilla'
# weight of a fused QKV linear layer
FUSED_QKV_LINEAR = 'fused_qkv_linear'
# weight of a fused gate and up linear layer
FUSED_GATE_UP_LINEAR = 'fused_gate_up_linear'
@dataclass(kw_only=True)
class WeightsLoadingConfig:
weight_mode: WeightMode = WeightMode.VANILLA
ignore_tensor_parallel: bool = False
class TensorParallelMode(str, enum.Enum):
COLUMN = 'column'
ROW = 'row'
@classmethod
def split_dim(cls, mode):
return 1 if mode == cls.ROW else 0
    # Helper used to shard the corresponding per-channel activation scales,
    # which are split along the dimension orthogonal to the weights.
@classmethod
def flip(cls, mode):
return cls.ROW if mode == cls.COLUMN else cls.COLUMN
def load_weight_shard(
weight,
tensor_parallel_size: int = 1,
tensor_parallel_rank: int = 0,
tensor_parallel_mode: Optional[TensorParallelMode] = None,
device: torch.device = torch.device('cpu'),
) -> torch.Tensor:
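    """
    Return the shard of `weight` owned by `tensor_parallel_rank`, moved to `device`.

    The weight is split along the dimension selected by `tensor_parallel_mode`
    (dim 0 for COLUMN, dim 1 for ROW) into chunks of ceil(width / tp_size);
    1-D tensors in ROW mode and tensors with width 1 are returned unsharded.
    Supports both materialized tensors and lazily-loaded safetensors slices.
    """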
if isinstance(weight, torch.Tensor):
tensor_shape = weight.shape
def maybe_convert_to_torch_tensor(tensor: torch.Tensor,
indices: slice = None):
if indices is None:
# Avoid unnecessary copy
return tensor.to(device)
else:
return tensor[indices].to(device)
    # WAR to check whether it is a safetensors slice, since safetensors does not register the type to the module.
    # A safetensors slice supports lazy loading; type(weight) is `builtin.PySafeSlice`.
elif hasattr(weight, "get_shape"):
tensor_shape = weight.get_shape()
def maybe_convert_to_torch_tensor(
tensor, indices: Union[slice, tuple[slice]] = slice(None)):
return tensor[indices].to(device)
else:
raise ValueError(f'unsupported weight type: {type(weight)}')
if tensor_parallel_mode is None or tensor_parallel_size <= 1:
return maybe_convert_to_torch_tensor(weight)
split_dim = TensorParallelMode.split_dim(tensor_parallel_mode)
if len(tensor_shape) == 1 and split_dim == 1:
return maybe_convert_to_torch_tensor(weight)
width = tensor_shape[split_dim]
if width == 1:
return maybe_convert_to_torch_tensor(weight)
slice_width = math.ceil(width / tensor_parallel_size)
slice_start = tensor_parallel_rank * slice_width
slice_end = min((tensor_parallel_rank + 1) * slice_width, width)
slice_obj = [slice(None)] * len(tensor_shape)
slice_obj[split_dim] = slice(slice_start, slice_end)
return maybe_convert_to_torch_tensor(weight, tuple(slice_obj))
def copy_weight(dst: Parameter, src: torch.Tensor):
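    """Copy `src` into `dst`, casting to `dst`'s dtype first if they differ."""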
    # TODO: check whether this is a reasonable change
if dst.dtype != src.dtype:
src = src.to(dst.dtype)
assert dst.dtype == src.dtype, f"Incompatible dtype. dst: {dst.dtype}, src: {src.dtype}"
dst.data.copy_(src)
def load_weights_vanilla_helper(module: Linear,
weights: List[Dict],
weight_transform=lambda x: x,
bias_transform=lambda x: x):
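    """
    Load a single checkpoint weight (and bias, if present) into `module`,
    sharded according to the module's TP configuration. For weight-only
    quantization the weight is additionally preprocessed for the mixed GEMM.
    """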
assert len(weights) == 1
device = torch.device('cuda')
weight = load_weight_shard(weights[0]['weight'], module.tp_size,
module.tp_rank, module.tp_mode, device)
if module.has_weight_only_quant:
        # NOTE: without this preprocessing at runtime, the GEMM outputs NaNs. In order to use
        # preprocess_weights_for_mixed_gemm, we need to cast the weight to int8 first.
activation_dtype = torch.float8_e4m3fn if module.has_w4a8_awq else torch.float16
weight_dtype, _ = get_weight_dtype_and_id(module)
weight = preprocess_weights_for_mixed_gemm(
weight.T.to(torch.int8).contiguous().cpu(), weight_dtype,
activation_dtype).cuda().contiguous()
copy_weight(module.weight, weight_transform(weight))
if module.bias is not None:
bias = load_weight_shard(weights[0]['bias'], module.tp_size,
module.tp_rank, module.tp_mode, device)
copy_weight(module.bias, bias_transform(bias))
def load_weights_fused_qkv_helper(
module: Linear,
weights: List[Dict],
weight_transform=lambda x: x,
bias_transform=lambda x: x
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
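    """
    Load and shard the q/k/v checkpoint weights for a fused QKV linear layer.
    Biases (if present) are concatenated and copied into `module.bias`; the
    (transformed) q/k/v weights are returned for the caller to fuse.
    """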
assert len(weights) == 3
device = torch.device('cuda')
q_weight = load_weight_shard(weights[0]['weight'], module.tp_size,
module.tp_rank, module.tp_mode, device)
k_weight = load_weight_shard(weights[1]['weight'], module.tp_size,
module.tp_rank, module.tp_mode, device)
v_weight = load_weight_shard(weights[2]['weight'], module.tp_size,
module.tp_rank, module.tp_mode, device)
if module.bias is not None:
q_bias = load_weight_shard(weights[0]['bias'], module.tp_size,
module.tp_rank, module.tp_mode, device)
k_bias = load_weight_shard(weights[1]['bias'], module.tp_size,
module.tp_rank, module.tp_mode, device)
v_bias = load_weight_shard(weights[2]['bias'], module.tp_size,
module.tp_rank, module.tp_mode, device)
copy_weight(module.bias,
bias_transform(torch.cat((q_bias, k_bias, v_bias))))
return tuple(map(weight_transform, (q_weight, k_weight, v_weight)))
def load_weights_fused_gate_up_helper(
module: Linear,
weights: List[Dict],
weight_transform=lambda x: x,
bias_transform=lambda x: x) -> tuple[torch.Tensor, torch.Tensor]:
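    """
    Load and shard the gate/up checkpoint weights for a fused gate-up linear
    layer. Biases (if present) are concatenated and copied into `module.bias`;
    the (transformed) gate/up weights are returned for the caller to fuse.
    """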
assert len(weights) == 2
device = torch.device('cuda')
gate_weight = load_weight_shard(weights[0]['weight'], module.tp_size,
module.tp_rank, module.tp_mode, device)
up_weight = load_weight_shard(weights[1]['weight'], module.tp_size,
module.tp_rank, module.tp_mode, device)
if module.bias is not None:
gate_bias = load_weight_shard(weights[0]['bias'], module.tp_size,
module.tp_rank, module.tp_mode, device)
up_bias = load_weight_shard(weights[1]['bias'], module.tp_size,
module.tp_rank, module.tp_mode, device)
copy_weight(module.bias, bias_transform(torch.cat(
(gate_bias, up_bias))))
return tuple(map(weight_transform, (gate_weight, up_weight)))
def get_weight_dtype_and_id(module: Linear) -> tuple[torch.dtype, int]:
"""
Get weight dtype and weight_id for weight only quantization mode.
Returns:
tuple[torch.dtype, int]: (weight_dtype, weight_id) where:
- weight_dtype: torch.int8 for INT8 weights, torch.quint4x2 for INT4 weights
- weight_id: 1 for INT8, 2 for INT4 (used for weight packing)
"""
assert module.quant_config is not None and module.quant_config.layer_quant_mode.is_weight_only(
), "This function should only be called when the module has weight-only quantization enabled."
if module.quant_config.layer_quant_mode.is_int8_weight_only():
return torch.int8, 1
elif module.quant_config.layer_quant_mode.is_int4_weight_only():
return torch.quint4x2, 2
else:
raise ValueError(
f"Unsupported quant_mode: {module.quant_config.layer_quant_mode}")
class LinearMethodBase(ABC):
"""
Base class for all linear methods.
"""
@abstractmethod
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype, *args,
**kwargs):
raise NotImplementedError
@abstractmethod
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor], *args, **kwargs):
raise NotImplementedError
def load_weights(self, module: Linear, weights: List[Dict],
weight_mode: WeightMode):
"""
Load weights from the checkpoint.
"""
if weight_mode == WeightMode.VANILLA:
self.load_weights_vanilla(module, weights)
elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
self.load_weights_fused_qkv_linear(module, weights)
elif weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
self.load_weights_fused_gate_up_linear(module, weights)
else:
raise ValueError(f'unsupported weight mode: {weight_mode}')
def load_weight_scales(self, weights: List[Dict], *args, **kwargs):
"""
Load quantized weight scales from the checkpoint.
"""
@abstractmethod
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
"""
Load weights for the VANILLA weight mode.
"""
raise NotImplementedError
@abstractmethod
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
"""
Load weights for the FUSED_QKV_LINEAR weight mode.
"""
raise NotImplementedError
@abstractmethod
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]) -> None:
"""
Load weights for the FUSED_GATE_UP_LINEAR weight mode.
"""
raise NotImplementedError
class UnquantizedLinearMethod(LinearMethodBase):
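    """
    Unquantized linear method using F.linear, or the custom cuBLAS matmul when
    `use_custom_cublas_mm` is enabled on the module.
    """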
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype):
weight_shape = (out_features, in_features)
module.weight = Parameter(torch.empty(weight_shape, dtype=dtype),
requires_grad=False)
if bias:
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
requires_grad=False)
else:
module.register_parameter("bias", None)
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
if module.use_custom_cublas_mm:
output = torch.ops.trtllm.cublas_mm(input,
module.weight.t(),
bias,
out_dtype=None)
else:
output = F.linear(input, module.weight, bias)
return output
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
copy_weight(module.weight, fused_weight)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]) -> None:
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
copy_weight(module.weight, fused_weight)
class FP8QDQLinearMethod(LinearMethodBase):
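    """
    FP8 per-tensor (QDQ) linear method. Activations are quantized to
    float8_e4m3fn per tensor (statically via `input_scale`, or dynamically),
    and the GEMM runs through cuda_scaled_mm / cublas_scaled_mm.
    """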
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype):
weight_shape = (out_features, in_features)
module.weight = Parameter(torch.empty(weight_shape,
dtype=torch.float8_e4m3fn),
requires_grad=False)
module.weight_scale = Parameter(torch.tensor(1., dtype=torch.float32),
requires_grad=False)
module.input_scale = Parameter(torch.tensor(1., dtype=torch.float32),
requires_grad=False)
module.inv_input_scale = Parameter(torch.tensor(1.,
dtype=torch.float32),
requires_grad=False)
if bias:
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
requires_grad=False)
else:
module.register_parameter("bias", None)
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
cur_input_scale = module.input_scale
if input.dtype != torch.float8_e4m3fn:
if module.input_scale is not None and not module.force_dynamic_quantization:
# Static quantization
qinput, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
input, module.input_scale)
else:
# Dynamic quantization
qinput, cur_input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
input)
cur_input_scale = cur_input_scale.to(torch.float32)
else:
qinput = input
# This op does not support bias now.
if module.enable_cuda_core and qinput.shape[0] <= 8:
# use cuda core for small m dimension
output = torch.ops.trtllm.cuda_scaled_mm(
qinput,
module.weight.t(),
scale_a=cur_input_scale,
scale_b=module.weight_scale,
bias=None,
out_dtype=module.dtype or input.dtype,
)
else:
output = torch.ops.trtllm.cublas_scaled_mm(
qinput,
module.weight.t(),
scale_a=cur_input_scale,
scale_b=module.weight_scale,
bias=None,
out_dtype=module.dtype or input.dtype,
)
if bias is not None:
output = output + bias
return output
def load_weight_scales(self, weights: List[Dict]):
input_scale, weight_scale = [], []
for w in weights:
if "input_scale" in w:
input_scale.append(w["input_scale"][...].reshape([]))
if "weight_scale" in w:
weight_scale.append(w["weight_scale"][...].reshape([]))
return input_scale, weight_scale
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
input_scale, weight_scale = self.load_weight_scales(weights)
if len(input_scale) != 0:
# Static quantization
copy_weight(module.input_scale, input_scale[0])
module.inv_input_scale.data = 1.0 / module.input_scale
else:
# Dynamic quantization
module.input_scale = None
module.inv_input_scale = None
copy_weight(module.weight_scale, weight_scale[0])
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
input_scale, weight_scale = self.load_weight_scales(weights)
if len(input_scale) != 0:
# Static quantization
copy_weight(module.input_scale, max(input_scale))
else:
# Dynamic quantization
module.input_scale = None
copy_weight(module.weight_scale, max(weight_scale))
q_weight = q_weight.to(module.dtype) * weight_scale[0]
k_weight = k_weight.to(module.dtype) * weight_scale[1]
v_weight = v_weight.to(module.dtype) * weight_scale[2]
fused_weight = torch.cat((q_weight, k_weight, v_weight))
if module.weight_scale.device != fused_weight.device:
module.weight_scale = Parameter(
module.weight_scale.data.to(fused_weight.device))
fused_weight = (fused_weight / module.weight_scale).to(
torch.float8_e4m3fn)
copy_weight(module.weight, fused_weight)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]) -> None:
input_scale, weight_scale = self.load_weight_scales(weights)
if len(input_scale) != 0:
# Static quantization
copy_weight(module.input_scale, max(input_scale))
else:
# Dynamic quantization
module.input_scale = None
copy_weight(module.weight_scale, max(weight_scale))
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
gate_weight = gate_weight.to(module.dtype) * weight_scale[0]
up_weight = up_weight.to(module.dtype) * weight_scale[1]
fused_weight = torch.cat((gate_weight, up_weight))
if module.weight_scale.device != fused_weight.device:
module.weight_scale = Parameter(
module.weight_scale.data.to(fused_weight.device))
fused_weight = (fused_weight / module.weight_scale).to(
torch.float8_e4m3fn)
copy_weight(module.weight, fused_weight)
class FP8RowwiseLinearMethod(LinearMethodBase):
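    """
    FP8 rowwise linear method: per-channel weight scales with dynamic
    per-token activation quantization, dispatched to fp8_rowwise_gemm.
    """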
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype):
weight_shape = (out_features, in_features)
module.weight = Parameter(torch.empty(weight_shape,
dtype=torch.float8_e4m3fn),
requires_grad=False)
module.weight_scale = Parameter(torch.empty(out_features),
requires_grad=False)
# Not really used for Gemm now.
# Only used to quantize output of FP8 attention.
module.input_scale = Parameter(torch.tensor(1., dtype=torch.float32),
requires_grad=False)
module.inv_input_scale = Parameter(torch.tensor(1.,
dtype=torch.float32),
requires_grad=False)
if bias:
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
requires_grad=False)
else:
module.register_parameter("bias", None)
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
# FP8 tensor inputs are from attention. Directly use ones as scale.
if input.dtype == torch.float8_e4m3fn:
qinput = input
cur_input_scale = torch.ones(input.shape[0],
device=input.device,
dtype=torch.float32)
else:
# Use dynamic per-token quantization for activation
qinput, cur_input_scale = torch.ops.tensorrt_llm.quantize_e4m3_activation(
input)
# This op does not support bias now.
output = torch.ops.trtllm.fp8_rowwise_gemm(
qinput,
module.weight,
cur_input_scale.float(),
module.weight_scale,
module.dtype or input.dtype,
)
if bias is not None:
output = output + bias
return output
def _get_scale_name(self, weights: List[Dict]):
# `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
# Actually they hold identical values of data_amax / 448.
scale_name = "weight_scale_inv"
if scale_name not in weights[0]:
scale_name = "weight_scale"
return scale_name
def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
load_weights_vanilla_helper(module, weights)
scale_name = self._get_scale_name(weights)
weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
copy_weight(module.weight_scale, weight_scale)
if "input_scale" in weights[0]:
copy_weight(module.input_scale, weights[0]["input_scale"])
module.inv_input_scale.data = 1.0 / module.input_scale
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]):
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
copy_weight(module.weight, fused_weight)
scale_name = self._get_scale_name(weights)
q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
copy_weight(module.weight_scale, fused_fp8_block_scale)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]):
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
copy_weight(module.weight, fused_weight)
scale_name = self._get_scale_name(weights)
left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
fused_scale = torch.cat((left_scale, right_scale))
copy_weight(module.weight_scale, fused_scale)
class FP8BlockScalesLinearMethod(LinearMethodBase):
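    """
    FP8 block-scales linear method (128x128 weight blocks, 1x128 activation
    blocks), e.g. for the DeepSeek recipe. On SM100 it uses fp8_swap_ab_gemm
    (or fp8_block_scaling_gemm when `use_cute_dsl_blockscaling_mm` is set);
    otherwise it quantizes the activation and uses fp8_block_scaling_gemm.
    """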
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype):
weight_shape = (out_features, in_features)
module.weight = Parameter(torch.empty(weight_shape,
dtype=torch.float8_e4m3fn),
requires_grad=False)
scale_shape = (math.ceil(out_features / 128),
math.ceil(in_features / 128))
module.weight_scale = Parameter(torch.empty(scale_shape,
dtype=torch.float32),
requires_grad=False)
# Not really used for Gemm now.
# Only used to quantize output of FP8 attention.
module.input_scale = Parameter(torch.tensor(1., dtype=torch.float32),
requires_grad=False)
module.inv_input_scale = Parameter(torch.tensor(1.,
dtype=torch.float32),
requires_grad=False)
if bias:
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
requires_grad=False)
else:
module.register_parameter("bias", None)
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
if input.dtype == torch.float8_e4m3fn:
input = input.to(torch.bfloat16) * module.input_scale
assert input.dtype == torch.bfloat16
if get_sm_version() == 100:
if module.use_cute_dsl_blockscaling_mm:
# TODO (@lmin): replace with cute_dsl gemm
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
input)
output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, module.weight, act_input_sf,
module.weight_scale)
else:
output = torch.ops.trtllm.fp8_swap_ab_gemm(
input,
module.weight,
module.weight_scale,
disable_ue8m0_cast=True,
)
else:
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
input)
output = torch.ops.trtllm.fp8_block_scaling_gemm(
act_input_fp8, module.weight, act_input_sf, module.weight_scale)
if bias is not None:
output = output + bias
return output
def _get_scale_name(self, weights: List[Dict]):
# `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe.
# Actually they hold identical values of data_amax / 448.
scale_name = "weight_scale_inv"
if scale_name not in weights[0]:
scale_name = "weight_scale"
return scale_name
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
scale_name = self._get_scale_name(weights)
weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
module.tp_rank,
module.tp_mode).squeeze()
copy_weight(module.weight_scale, weight_scale)
if "input_scale" in weights[0]:
copy_weight(module.input_scale, weights[0]["input_scale"])
module.inv_input_scale.data = 1.0 / module.input_scale
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
scale_name = self._get_scale_name(weights)
q_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
k_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
copy_weight(module.weight, fused_weight)
copy_weight(module.weight_scale, fused_fp8_block_scale)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]) -> None:
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
scale_name = self._get_scale_name(weights)
left_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
module.tp_rank, module.tp_mode)
fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
copy_weight(module.weight, fused_weight)
copy_weight(module.weight_scale, fused_scale)
class NVFP4LinearMethod(LinearMethodBase):
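    """
    NVFP4 linear method: FP4 (e2m1) weights with FP8 block scaling factors
    (16-element vectors) plus global per-tensor scales, dispatched to
    nvfp4_gemm.
    """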
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype):
module.scaling_vector_size = 16
assert in_features % module.scaling_vector_size == 0, f"in_features {in_features} must be divisible by scaling_vector_size {module.scaling_vector_size}"
# Quantized weights
module.weight = Parameter(torch.empty([out_features, in_features // 2],
dtype=fp4_utils.float4_e2m1x2),
requires_grad=False)
# FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE
# Padding is required. See computeSFSize in quantization.h
nrows = fp4_utils.pad_up(out_features, 128)
ncols = fp4_utils.pad_up(in_features // module.scaling_vector_size, 4)
module.weight_scale = Parameter(torch.empty(
[nrows * ncols], dtype=fp4_utils.float4_sf_dtype),
requires_grad=False)
# FP32 per-tensor global scaling factor = 448*6/amax_input
module.input_scale = Parameter(torch.empty([1], dtype=torch.float32),
requires_grad=False)
module.inv_input_scale = Parameter(torch.empty([1],
dtype=torch.float32),
requires_grad=False)
# (amax_input * amax_weight) / (448*6 * 448*6)
module.alpha = Parameter(torch.empty([1], dtype=torch.float32),
requires_grad=False)
if bias:
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
requires_grad=False)
else:
module.register_parameter("bias", None)
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
if isinstance(input, Fp4QuantizedTensor):
act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
elif isinstance(input, tuple):
act_fp4, act_sf = input
else:
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
input, module.input_scale, module.scaling_vector_size, False)
output = torch.ops.trtllm.nvfp4_gemm(act_fp4, module.weight, act_sf,
module.weight_scale, module.alpha,
module.dtype)
if bias is not None:
output = output + bias
return output
def load_weight_scales(self,
weights: List[Dict],
tp_size: int = 1,
tp_rank: int = 0,
tp_mode: Optional[TensorParallelMode] = None):
# For concatenated weights (qkv_proj / up_gate_proj), the global scaling factors and input scaling factors should be shared.
input_scale = None
weight_scale_2 = None
weight_scale = []
device = torch.device("cuda")
for w in weights:
if "input_scale" in w:
if input_scale is None:
input_scale = w["input_scale"][...]
else:
assert input_scale == w["input_scale"][
...], "The input_scale should be same for all the weights"
if "weight_scale" in w:
ws = load_weight_shard(w["weight_scale"],
tp_size,
tp_rank,
tp_mode,
device=device).contiguous()
assert ws.dtype == torch.float8_e4m3fn # TODO: or e8m0 for mxfp4 recipe?
weight_scale.append(ws.view(fp4_utils.float4_sf_dtype))
if "weight_scale_2" in w:
if weight_scale_2 is None:
weight_scale_2 = w["weight_scale_2"][...]
else:
assert weight_scale_2 == w["weight_scale_2"][
...], "The weight_scale_2 should be same for all the weights"
# Compute scaling factor and alpha required by GEMM kernels
# TODO: ModelOpt's o_proj.weight_scale_2 is bfloat16, which should be float32
alpha = input_scale.float() * weight_scale_2.float()
# modelopt ckpt stores amax/(448*6), convert to (448*6)/amax
input_scale = 1.0 / input_scale
return input_scale, weight_scale, alpha
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
input_scale, weight_scale, alpha = self.load_weight_scales(
weights,
tp_size=module.tp_size,
tp_rank=module.tp_rank,
tp_mode=module.tp_mode)
assert len(weights) == 1
weight_scale = weight_scale[0]
# Swizzle weight scale
weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale)
copy_weight(module.input_scale, input_scale)
copy_weight(module.weight_scale, weight_scale)
E2M1_MAX = 6.0
module.inv_input_scale.data = module.input_scale / E2M1_MAX
copy_weight(module.alpha, alpha)
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
input_scale, weight_scales, alpha = self.load_weight_scales(
weights,
tp_size=module.tp_size,
tp_rank=module.tp_rank,
tp_mode=module.tp_mode)
# Swizzle weight scales after concatenation
weight_scale = torch.cat(weight_scales, 0)
weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale)
copy_weight(module.input_scale, input_scale)
copy_weight(module.weight_scale, weight_scale)
copy_weight(module.alpha, alpha)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
copy_weight(module.weight, fused_weight)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]) -> None:
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
copy_weight(module.weight, fused_weight)
input_scale, weight_scales, alpha = self.load_weight_scales(
weights,
tp_size=module.tp_size,
tp_rank=module.tp_rank,
tp_mode=module.tp_mode)
# Swizzle weight scales after concatenation
weight_scale = torch.cat(weight_scales, 0)
weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale)
copy_weight(module.input_scale, input_scale)
copy_weight(module.weight_scale, weight_scale)
copy_weight(module.alpha, alpha)
class W4A8MXFP4FP8LinearMethod(LinearMethodBase):
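    """
    W4A8 linear method with MXFP4 weights (32-element E8M0 block scales) and
    per-tensor FP8 activations, dispatched to w4a8_mxfp4_fp8_gemm.
    """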
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype):
module.scaling_vector_size = 32
assert module.in_features % module.scaling_vector_size == 0, f"in_features {module.in_features} must be divisible by scaling_vector_size {module.scaling_vector_size}"
# Quantized weights
module.weight = Parameter(torch.empty(
[module.out_features, module.in_features // 2],
dtype=fp4_utils.float4_e2m1x2),
requires_grad=False)
# FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE
# Padding is required. See computeSFSize in quantization.h
nrows = fp4_utils.pad_up(module.out_features, 128)
ncols = fp4_utils.pad_up(
module.in_features // module.scaling_vector_size, 4)
module.weight_scale = Parameter(torch.empty(
[nrows * ncols], dtype=fp4_utils.float4_sf_dtype),
requires_grad=False)
if bias:
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
requires_grad=False)
else:
module.register_parameter("bias", None)
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
fp8_input, input_scale = torch.ops.tensorrt_llm.quantize_e4m3_per_tensor(
input)
input_scale = input_scale.to(torch.float32)
nrows = fp4_utils.pad_up(input.shape[0], 128)
ncols = fp4_utils.pad_up(input.shape[1] // module.scaling_vector_size,
4)
# 01111111 is 2^(127 - 127) = 1 in E8M0
module.fake_act_scale = torch.empty(
[nrows * ncols], dtype=torch.uint8,
device=fp8_input.device).fill_(127).view(fp4_utils.float4_sf_dtype)
output = torch.ops.trtllm.w4a8_mxfp4_fp8_gemm(fp8_input, module.weight,
module.fake_act_scale,
module.weight_scale,
input_scale, module.dtype)
if bias is not None:
output = output + bias
return output
def load_weight_scales(self,
weights: List[Dict],
tp_size: int = 1,
tp_rank: int = 0,
tp_mode: Optional[TensorParallelMode] = None):
# For concatenated weights (qkv_proj / up_gate_proj), the global scaling factors and input scaling factors should be shared.
weight_scale = []
device = torch.device("cuda")
for w in weights:
if "weight_scale" in w:
ws = load_weight_shard(w["weight_scale"],
tp_size,
tp_rank,
tp_mode,
device=device).contiguous()
# Should be E8M0 for MXFP4
assert ws.dtype == torch.uint8
weight_scale.append(ws.view(fp4_utils.float4_sf_dtype))
return weight_scale
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
weight_scale = self.load_weight_scales(weights,
tp_size=module.tp_size,
tp_rank=module.tp_rank,
tp_mode=module.tp_mode)
assert len(weights) == 1
weight_scale = weight_scale[0]
# Swizzle weight scale
weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale)
copy_weight(module.weight_scale, weight_scale)
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
copy_weight(module.weight, fused_weight)
weight_scale = self.load_weight_scales(weights,
tp_size=module.tp_size,
tp_rank=module.tp_rank,
tp_mode=module.tp_mode)
weight_scale = torch.cat(weight_scale, 0)
weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale)
copy_weight(module.weight_scale, weight_scale)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]) -> None:
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
copy_weight(module.weight, fused_weight)
weight_scale = self.load_weight_scales(weights,
tp_size=module.tp_size,
tp_rank=module.tp_rank,
tp_mode=module.tp_mode)
# Swizzle weight scales after concatenation
weight_scale = torch.cat(weight_scale, 0)
weight_scale = torch.ops.trtllm.block_scale_interleave(weight_scale)
copy_weight(module.weight_scale, weight_scale)
class WeightOnlyQuantLinearMethod(LinearMethodBase):
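    """
    Weight-only quantization (INT8 or INT4 weights packed into int8) with
    per-channel scales, dispatched to weight_only_quant_gemm.
    """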
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool,
dtype: torch.dtype) -> None:
_, weight_id = get_weight_dtype_and_id(module)
# Quantized weights (int4 weights are packed into int8)
module.weight = Parameter(torch.empty(
(in_features, out_features // weight_id), dtype=torch.int8),
requires_grad=False)
module.weight_scale = Parameter(torch.empty((out_features),
dtype=dtype),
requires_grad=False)
if bias:
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
requires_grad=False)
else:
module.register_parameter("bias", None)
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
weight_dtype, _ = get_weight_dtype_and_id(module)
bias = bias.contiguous() if bias is not None else None
output = torch.ops.trtllm.weight_only_quant_gemm(
input, module.weight, weight_dtype, module.weight_scale,
module.dtype)
        if bias is not None:
            output = output + bias
        return output
def load_weight_scales(
self,
weights: List[Dict],
tp_size: int = 1,
tp_rank: int = 0,
tp_mode: Optional[TensorParallelMode] = None) -> List[torch.Tensor]:
device = torch.device("cuda")
q_weight_scale = load_weight_shard(weights[0]['weight_scale'],
tp_size,
tp_rank,
tp_mode,
device=device)
k_weight_scale = load_weight_shard(weights[1]['weight_scale'],
tp_size,
tp_rank,
tp_mode,
device=device)
v_weight_scale = load_weight_shard(weights[2]['weight_scale'],
tp_size,
tp_rank,
tp_mode,
device=device)
weight_scales = [q_weight_scale, k_weight_scale, v_weight_scale]
return weight_scales
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
device = torch.device('cuda')
weight_scale = load_weight_shard(weights[0]['weight_scale'],
module.tp_size, module.tp_rank,
module.tp_mode, device)
copy_weight(module.weight_scale, weight_scale)
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
weight_dtype, _ = get_weight_dtype_and_id(module)
fused_weight = preprocess_weights_for_mixed_gemm(
fused_weight.to(torch.int8).T.contiguous().cpu(), weight_dtype,
torch.float16).cuda().contiguous()
copy_weight(module.weight, fused_weight)
weight_scales = self.load_weight_scales(weights)
# Create concatenated weight scale tensor
cat_weight_scale = torch.cat(weight_scales, dim=0)
copy_weight(module.weight_scale, cat_weight_scale)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]) -> None:
device = torch.device('cuda')
weight_dtype, _ = get_weight_dtype_and_id(module)
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
fused_weight = preprocess_weights_for_mixed_gemm(
fused_weight.to(torch.int8).T.contiguous().cpu(), weight_dtype,
torch.float16).cuda().contiguous()
copy_weight(module.weight, fused_weight)
left_scale = load_weight_shard(weights[0]['weight_scale'],
module.tp_size, module.tp_rank,
module.tp_mode, device).contiguous()
right_scale = load_weight_shard(weights[1]['weight_scale'],
module.tp_size, module.tp_rank,
module.tp_mode, device).contiguous()
fused_scale = torch.cat([left_scale, right_scale], dim=0)
copy_weight(module.weight_scale, fused_scale)
class W4A16_AWQ_LinearMethod(LinearMethodBase):
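    """
    W4A16 AWQ linear method: INT4 per-group weights with half-precision
    activations and an optional per-channel pre_quant_scale on the activation
    (present only when it is not fused into the preceding LayerNorm).
    """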
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool,
dtype: torch.dtype) -> None:
# Quantized weights
module.weight = Parameter(torch.empty(
(in_features, out_features // 2),
dtype=torch.int8,
),
requires_grad=False)
group_size = module.quant_config.group_size
if in_features % group_size != 0:
raise ValueError(
f"in_features ({self.in_features}) must be divisible by group_size ({group_size}) "
f"for INT4 per-group quantization scale dimensions.")
module.weight_scale = Parameter(torch.empty(
(in_features // group_size, out_features), dtype=dtype),
requires_grad=False)
        # NOTE: not every linear layer has this tensor - pre_quant_scale is computed as an average and merged into the
        # LayerNorm for the QKV and Gate/Up projection layers when possible, so the tensor is only present for o_proj and down_proj.
module.pre_quant_scale = None
if bias:
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
requires_grad=False)
else:
module.register_parameter("bias", None)
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
if module.pre_quant_scale is not None:
input = input * module.pre_quant_scale
bias = bias.contiguous() if bias is not None else None
output = torch.ops.trtllm.finegrained_mixed_dtype_gemm(
input=input.to(module.dtype).contiguous(),
weight=module.weight,
scales=module.weight_scale,
group_size=module.quant_config.group_size,
has_zero_point=module.quant_config.has_zero_point,
output_dtype=module.dtype or input.dtype,
bias=bias,
zeros=None)
return output
def load_weight_scales(
self,
weights: List[Dict],
tp_size: int = 1,
tp_rank: int = 0,
tp_mode: Optional[TensorParallelMode] = None) -> List[torch.Tensor]:
device = torch.device("cuda")
q_weight_scale = load_weight_shard(weights[0]['weight_scale'],
tp_size,
tp_rank,
tp_mode,
device=device)
k_weight_scale = load_weight_shard(weights[1]['weight_scale'],
tp_size,
tp_rank,
tp_mode,
device=device)
v_weight_scale = load_weight_shard(weights[2]['weight_scale'],
tp_size,
tp_rank,
tp_mode,
device=device)
weight_scales = [q_weight_scale, k_weight_scale, v_weight_scale]
return weight_scales
def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
load_weights_vanilla_helper(module, weights)
        # Use the same device as the weight tensor,
        # as we register pre_quant_scale after the sharded model weights are moved to their respective GPUs.
device = module.weight.device
pre_quant_scale = load_weight_shard(
weights[0]["pre_quant_scale"],
module.tp_size,
module.tp_rank,
# pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around
TensorParallelMode.flip(module.tp_mode),
device,
)
module.pre_quant_scale = Parameter(
torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
requires_grad=False).to(device=device)
weight_scale = load_weight_shard(weights[0]['weight_scale'],
module.tp_size, module.tp_rank,
module.tp_mode, device)
copy_weight(module.pre_quant_scale, pre_quant_scale)
copy_weight(module.weight_scale, weight_scale.T.contiguous())
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]) -> None:
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
fused_weight = preprocess_weights_for_mixed_gemm(
fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2,
torch.float16).cuda().contiguous()
copy_weight(module.weight, fused_weight)
weight_scales = self.load_weight_scales(weights)
# Create concatenated weight scale tensor
cat_weight_scale = torch.cat(weight_scales, dim=0).T.contiguous()
copy_weight(module.weight_scale, cat_weight_scale)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]) -> None:
device = torch.device('cuda')
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
fused_weight = preprocess_weights_for_mixed_gemm(
fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2,
torch.float16).cuda().contiguous()
copy_weight(module.weight, fused_weight)
left_scale = load_weight_shard(weights[0]['weight_scale'],
module.tp_size, module.tp_rank,
module.tp_mode, device).contiguous()
right_scale = load_weight_shard(weights[1]['weight_scale'],
module.tp_size, module.tp_rank,
module.tp_mode, device).contiguous()
fused_scale = torch.cat([left_scale, right_scale], dim=0).T.contiguous()
copy_weight(module.weight_scale, fused_scale)
class W4A8_AWQ_LinearMethod(LinearMethodBase):
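    """
    W4A8 AWQ linear method: INT4 per-group weights with per-tensor FP8
    activations. See `apply` for the ModelOpt quantization flow.
    """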
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype):
# Quantized weights
module.weight = Parameter(torch.empty(
(in_features, out_features // 2),
dtype=torch.int8,
),
requires_grad=False)
group_size = module.quant_config.group_size
if in_features % group_size != 0:
raise ValueError(
f"in_features ({module.in_features}) must be divisible by group_size ({group_size}) "
f"for INT4 per-group quantization scale dimensions.")
        # NOTE: for FP8 activation, scales need to be float16
module.weight_scale = Parameter(torch.empty(
(in_features // group_size, out_features), dtype=torch.float16),
requires_grad=False)
# Similar to W4A16 AWQ, not all linears will have this tensor
module.pre_quant_scale = None
module.input_scale = Parameter(torch.tensor(1., dtype=torch.float32),
requires_grad=False)
module.inv_input_scale = Parameter(torch.tensor(1.,
dtype=torch.float32),
requires_grad=False)
module.alpha = Parameter(torch.empty([1], dtype=torch.float32),
requires_grad=False)
        # WAR for CUDA graph: the mixed w4a8 GEMM does not accept alpha as a device buffer,
        # so we keep a separate plain Python float that is updated during weight loading.
module.alpha_value = 1.0
if bias:
module.bias = Parameter(torch.empty((out_features), dtype=dtype),
requires_grad=False)
else:
module.register_parameter("bias", None)
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
"""
modelopt flow for w4a8_awq:
1. multiply pre_quant_scale to input
2. quantize input to fp8 using input_scale
3. unpack_weights and multiply by weight_scales (int4 -> fp16)
4. divied by weight_scale_2 (fp16 -> fp8 to allow gemm in fp8).
5. apply gemm in fp8.
6. rescale using alpha which is input_scale * weight_scale_2
"""
if module.pre_quant_scale is not None:
input = input * module.pre_quant_scale
if input.dtype == torch.float8_e4m3fn:
quantized_input = input
else:
quantized_input, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
input, (module.input_scale))
bias = bias.contiguous() if bias is not None else None
output = torch.ops.trtllm.finegrained_mixed_dtype_gemm(
input=quantized_input.contiguous(),
weight=module.weight,
scales=module.weight_scale,
group_size=module.quant_config.group_size,
has_zero_point=module.quant_config.has_zero_point,
output_dtype=module.dtype
or input.dtype, # NOTE: output_dtype can only be bf16/fp16 for W4A8
alpha=module.alpha_value,
bias=bias,
zeros=None)
return output
def load_weight_scales_w4a8(self,
weights: List[Dict],
tp_size: int = 1,
tp_rank: int = 0,
tp_mode: Optional[TensorParallelMode] = None):
# For concatenated weights (qkv_proj / up_gate_proj), the global scaling factors and input scaling factors should be shared.
input_scale = None
weight_scale_2 = None
weight_scale = []
device = torch.device("cuda")
for w in weights:
if "input_scale" in w:
if input_scale is None:
input_scale = w["input_scale"][...]
else:
assert input_scale == w["input_scale"][
...], "The input_scale should be same for all the weights"
if "weight_scale" in w:
ws = load_weight_shard(w["weight_scale"],
tp_size,
tp_rank,
tp_mode,
device=device)
weight_scale.append(ws.to(torch.float16))
if "weight_scale_2" in w:
if weight_scale_2 is None:
weight_scale_2 = w["weight_scale_2"][...]
else:
assert weight_scale_2 == w["weight_scale_2"][
...], "The weight_scale_2 should be same for all the weights"
# Compute scaling factor and alpha required by GEMM kernels (rescale the gemm output in fp8)
alpha = (input_scale.float() * weight_scale_2.float())
return input_scale, weight_scale, alpha, weight_scale_2
def load_weights_vanilla(self, module: Linear, weights: List[Dict]):
load_weights_vanilla_helper(module, weights)
        # Use the same device as the weight tensor,
        # as we register pre_quant_scale after the sharded model weights are moved to their respective GPUs.
device = module.weight.device
pre_quant_scale = load_weight_shard(
weights[0]["pre_quant_scale"],
module.tp_size,
module.tp_rank,
# pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around
TensorParallelMode.flip(module.tp_mode),
device,
)
assert pre_quant_scale.dtype == module.dtype
module.pre_quant_scale = Parameter(
torch.empty((module.in_features, ), dtype=pre_quant_scale.dtype),
requires_grad=False).to(device=device)
copy_weight(module.pre_quant_scale, pre_quant_scale)
input_scale, weight_scale, alpha, weight_scale_2 = self.load_weight_scales_w4a8(
weights=weights,
tp_size=module.tp_size,
tp_rank=module.tp_rank,
tp_mode=module.tp_mode)
assert len(weight_scale) == 1, "there should be only one weight scale"
weight_scale = (weight_scale[0].T / weight_scale_2).contiguous()
copy_weight(module.weight_scale, weight_scale)
copy_weight(module.input_scale, input_scale)
copy_weight(module.alpha, alpha)
module.alpha_value = alpha.item()
module.inv_input_scale.data = 1.0 / module.input_scale
def load_weights_fused_qkv_linear(self, module: Linear,
weights: List[Dict]):
q_weight, k_weight, v_weight = load_weights_fused_qkv_helper(
module, weights)
fused_weight = torch.cat((q_weight, k_weight, v_weight))
fused_weight = preprocess_weights_for_mixed_gemm(
fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2,
torch.float8_e4m3fn).cuda().contiguous()
copy_weight(module.weight, fused_weight)
input_scale, weight_scales, alpha, weight_scale_2 = self.load_weight_scales_w4a8(
weights=weights,
tp_size=module.tp_size,
tp_rank=module.tp_rank,
tp_mode=module.tp_mode)
# Create concatenated weight scale tensor
cat_weight_scale = (torch.cat(weight_scales, dim=0).T /
weight_scale_2).contiguous()
copy_weight(module.weight_scale, cat_weight_scale)
copy_weight(module.input_scale, input_scale)
copy_weight(module.alpha, alpha)
module.alpha_value = alpha.item()
        # NOTE: pre_quant_scale is the same for q, k, v since ModelOpt checks which layers share the same input and creates an averaged pre_quant_scale.
        # Usually when ModelOpt exports the quantized model, pre_quant_scale is fused into the layer norm (this case is relevant only if that fusion is disabled - ModelOpt internal).
if "pre_quant_scale" in weights[0].keys():
            # Use the same device as the weight tensor,
            # as we register pre_quant_scale after the sharded model weights are moved to their respective GPUs.
device = module.weight.device
pre_quant_scale = load_weight_shard(
weights[0]["pre_quant_scale"],
module.tp_size,
module.tp_rank,
# pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around
TensorParallelMode.flip(module.tp_mode),
device,
)
module.pre_quant_scale = Parameter(
torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
requires_grad=False).to(device=torch.device('cuda'))
copy_weight(module.pre_quant_scale, pre_quant_scale)
def load_weights_fused_gate_up_linear(self, module: Linear,
weights: List[Dict]):
gate_weight, up_weight = load_weights_fused_gate_up_helper(
module, weights)
fused_weight = torch.cat((gate_weight, up_weight))
fused_weight = preprocess_weights_for_mixed_gemm(
fused_weight.to(torch.int8).T.contiguous().cpu(), torch.quint4x2,
torch.float8_e4m3fn).cuda().contiguous()
copy_weight(module.weight, fused_weight)
input_scale, weight_scale, alpha, weight_scale_2 = self.load_weight_scales_w4a8(
weights=weights,
tp_size=module.tp_size,
tp_rank=module.tp_rank,
tp_mode=module.tp_mode)
fused_scale = (torch.cat(weight_scale, dim=0).T /
weight_scale_2).contiguous()
copy_weight(module.weight_scale, fused_scale)
copy_weight(module.input_scale, input_scale)
copy_weight(module.alpha, alpha)
module.alpha_value = alpha.item()
if "pre_quant_scale" in weights[0].keys():
            # Use the same device as the weight tensor,
            # as we register pre_quant_scale after the sharded model weights are moved to their respective GPUs.
device = module.weight.device
pre_quant_scale = load_weight_shard(
weights[0]["pre_quant_scale"],
module.tp_size,
module.tp_rank,
# pre_quant_scale applies to activation as opposed to weight, so flip tp_mode the other way around
TensorParallelMode.flip(module.tp_mode),
device,
)
            # NOTE: create this tensor in load_weights, since not all layers have this tensor and memory is not allocated for it otherwise (same as W4A16).
module.pre_quant_scale = Parameter(
torch.ones((module.in_features, ), dtype=pre_quant_scale.dtype),
requires_grad=False).to(device=torch.device('cuda'))
copy_weight(module.pre_quant_scale, pre_quant_scale)
class W4A8MXFP4MXFP8LinearMethod(W4A8MXFP4FP8LinearMethod):
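    """
    Variant of W4A8MXFP4FP8LinearMethod that quantizes activations to MXFP8
    (per-block scales) instead of per-tensor FP8.
    """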
def create_weights(self, module: Linear, in_features: int,
out_features: int, bias: bool, dtype: torch.dtype):
super().create_weights(module, in_features, out_features, bias, dtype)
module.scale_one = torch.tensor([1.0], dtype=torch.float32).cuda()
def apply(self, module: Linear, input: torch.Tensor,
bias: Optional[torch.Tensor]):
# requires the swizzled block scales.
fp8_input, input_scales = torch.ops.trtllm.mxfp8_quantize(input, True)
output = torch.ops.trtllm.w4a8_mxfp4_fp8_gemm(fp8_input, module.weight,
input_scales,
module.weight_scale,
module.scale_one,
module.dtype)
if bias is not None:
output = output + bias
return output
def get_quant_method(quant_config: Optional[QuantConfig] = None):
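    """
    Select the LinearMethod implementation matching `quant_config`; returns
    UnquantizedLinearMethod when no (non-KV-cache) quantization is enabled.
    """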
if quant_config is None or not quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True):
return UnquantizedLinearMethod()
if quant_config.layer_quant_mode.has_fp8_qdq():
return FP8QDQLinearMethod()
if quant_config.layer_quant_mode.has_fp8_rowwise():
return FP8RowwiseLinearMethod()
if quant_config.layer_quant_mode.has_fp8_block_scales():
return FP8BlockScalesLinearMethod()
if quant_config.layer_quant_mode.has_nvfp4():
return NVFP4LinearMethod()
if quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
return W4A8MXFP4FP8LinearMethod()
if quant_config.layer_quant_mode.is_weight_only(
) and not quant_config.layer_quant_mode.has_per_group_scaling():
return WeightOnlyQuantLinearMethod()
if quant_config.layer_quant_mode.is_int4_weight_only_per_group(
) and quant_config.quant_algo == QuantAlgo.W4A16_AWQ:
return W4A16_AWQ_LinearMethod()
if quant_config.layer_quant_mode.is_int4_weight_only_per_group(
) and quant_config.quant_algo == QuantAlgo.W4A8_AWQ:
return W4A8_AWQ_LinearMethod()
if quant_config.layer_quant_mode.has_w4a8_mxfp4_mxfp8():
return W4A8MXFP4MXFP8LinearMethod()
raise ValueError(f'unsupported quant mode: {quant_config.quant_mode}')
class Linear(nn.Module):
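    """
    Linear layer with optional tensor parallelism (COLUMN or ROW), optional
    all-gather (COLUMN) or all-reduce (ROW) of the output, quantization chosen
    via `get_quant_method`, and optional LoRA.
    """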
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
mapping: Optional[Mapping] = None,
tensor_parallel_mode: Optional[TensorParallelMode] = None,
gather_output: bool = False, # COLUMN parallel only
quant_config: Optional[QuantConfig] = None,
weights_loading_config: Optional[WeightsLoadingConfig] = None,
reduce_output: bool = True, # ROW parallel only
skip_create_weights_in_init: bool = False,
use_custom_cublas_mm: bool = False,
lora: Optional[LoraLayer] = None,
allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO,
force_dynamic_quantization: bool = False,
use_cute_dsl_blockscaling_mm: bool = False,
):
from ..distributed import AllReduce
super().__init__()
self.has_bias = bias
self.dtype = dtype
self.mapping = mapping or Mapping()
# could be modified later
self.quant_config = quant_config
self.weights_loading_config = weights_loading_config or WeightsLoadingConfig(
)
self.tp_size = self.mapping.tp_size
self.tp_rank = self.mapping.tp_rank
self.tp_mode = tensor_parallel_mode
self.gather_output = gather_output
self.force_dynamic_quantization = force_dynamic_quantization
self.use_cute_dsl_blockscaling_mm = use_cute_dsl_blockscaling_mm
local_in_features = in_features
local_out_features = out_features
if self.tp_mode == TensorParallelMode.ROW:
assert in_features % self.tp_size == 0, (
f'in_features {in_features} must be divisible by tp_size {self.tp_size}'
)
local_in_features = in_features // self.tp_size
elif self.tp_mode == TensorParallelMode.COLUMN:
assert out_features % self.tp_size == 0, (
f'out_features {out_features} must be divisible by tp_size {self.tp_size}'
)
local_out_features = out_features // self.tp_size
else:
assert self.tp_mode is None, (
                f'unsupported tensor parallel mode: {self.tp_mode}')
self.in_features = local_in_features
self.out_features = local_out_features
self.all_reduce = AllReduce(mapping=self.mapping,
strategy=allreduce_strategy,
dtype=self.dtype) if reduce_output else None
self._weights_created = False
self.reduce_output = reduce_output
self.use_custom_cublas_mm = use_custom_cublas_mm
self.lora = lora
self.enable_cuda_core = False
if torch.cuda.is_available():
capability = torch.cuda.get_device_capability(
torch.device('cuda:0'))
# enable cuda core for sm89
self.enable_cuda_core = capability[0] == 8 and capability[1] == 9
if not skip_create_weights_in_init:
self.create_weights()
def get_quant_method(self, quant_config: Optional[QuantConfig] = None):
return get_quant_method(quant_config)
def create_weights(self):
if self._weights_created:
return
self.quant_method = self.get_quant_method(self.quant_config)
self.quant_method.create_weights(self, self.in_features,
self.out_features, self.has_bias,
self.dtype)
self._weights_created = True
@property
def has_any_quant(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
exclude_kv_cache=True)
@property
def has_fp8_qdq(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp8_qdq(
)
@property
def has_fp8_rowwise(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp8_rowwise(
)
@property
def has_fp8_block_scales(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_fp8_block_scales(
)
@property
def has_nvfp4(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
)
@property
def has_weight_only_quant(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.is_weight_only(
)
@property
def has_w4a16_awq(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
) and self.quant_config.quant_algo == QuantAlgo.W4A16_AWQ
@property
def has_w4a8_awq(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
) and self.quant_config.quant_algo == QuantAlgo.W4A8_AWQ
@property
def has_w4a8_mxfp4_fp8(self):
assert self._weights_created
return self.quant_config is not None and self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8(
)
def apply_linear(self,
input,
bias,
                     lora_params: Optional[dict] = None,
                     layer_idx: Optional[int] = None):
output = self.quant_method.apply(self, input, bias)
if self.lora is not None and bool(lora_params):
lora_result = self.lora(input, lora_params, layer_idx)
if lora_result is not None:
output = output + lora_result
return output
def _maybe_fuse_bias_into_allreduce(
self,
bias: Optional[torch.Tensor],
all_reduce_params: Optional[AllReduceParams] = None,
) -> bool:
if self.tp_size > 1:
fuse_bias_into_all_reduce = (
bias is not None and all_reduce_params is not None
and (all_reduce_params.fusion_op
== AllReduceFusionOp.RESIDUAL_RMS_NORM))
if fuse_bias_into_all_reduce:
all_reduce_params.bias = bias
return True
else:
assert all_reduce_params is None or all_reduce_params.enable_allreduce is False, "Cannot fuse norm/residual/bias ops into allreduce op since we do not call allreduce op when tp_size is 1."
return False
def forward(
self,
input: Union[torch.Tensor, Fp4QuantizedTensor],
*,
all_reduce_params: Optional[AllReduceParams] = None,
lora_params: Optional[dict] = None,
layer_idx: Optional[int] = None,
) -> torch.Tensor:
if self.tp_mode == TensorParallelMode.ROW:
bias = None if (self.tp_rank > 0) else self.bias
if self.reduce_output:
fuse_bias = self._maybe_fuse_bias_into_allreduce(
bias, all_reduce_params)
bias = None if fuse_bias else bias
output = self.apply_linear(input, bias, lora_params, layer_idx)
output = self.all_reduce(
output,
all_reduce_params=all_reduce_params,
)
else:
output = self.apply_linear(input, bias, lora_params, layer_idx)
elif self.tp_mode == TensorParallelMode.COLUMN:
output = self.apply_linear(input, self.bias, lora_params, layer_idx)
if self.gather_output:
from ..distributed import allgather
output = allgather(output, self.mapping)
else:
output = self.apply_linear(input, self.bias, lora_params, layer_idx)
return output
def load_weights(self, weights: List[Dict]):
assert self._weights_created
weight_mode = self.weights_loading_config.weight_mode
self.quant_method.load_weights(self, weights, weight_mode)