TensorRT-LLMs/tensorrt_llm/_torch/modules/linear.py
liji-nv dca6397d1e
feat: Introduce UB allocator for pytorch flow (#3257)
* Instead of allocating UserBuffers at the beginning of runtime, UB buffers
  are now managed by a global allocator. The allocator dynamically assigns a
  free UB buffer or allocates a new one for a torch tensor, which makes
  userbuffers easier to use.

* In the common use case, the UserBuffers are allocated correctly during the
  warm-up stage, so there is no dynamic allocation during inference.

* The UB fusion pattern is rewritten using the new UB Allocator. It consists
  of the following passes (sketched below):

1. Fuse quant with allreduce, replace it with the UB implementation, and
   insert a copy_to_userbuffers. The normal allreduce currently does not
   support FP8 quant, so this has to be done in the UB pass.
2. Convert all supported allreduce ops to UB and insert copy_to_userbuffers.
3. Fuse the op preceding the allreduce with copy_to_userbuffers so that the
   op writes directly into the userbuffer.
4. Remove the userbuffers finalize if the output is connected to another UB
   allreduce.
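
A minimal sketch of the pass ordering, in pseudo-Python; the helper names
(fuse_quant_allreduce_ub and friends) are hypothetical and only illustrate the
intent of each pass described above:

    def run_ub_fusion_passes(graph):
        # 1. allreduce + FP8 quant -> UB allreduce, plus copy_to_userbuffers
        fuse_quant_allreduce_ub(graph)
        # 2. remaining supported allreduces -> UB allreduce + copy_to_userbuffers
        convert_allreduce_to_ub(graph)
        # 3. let the producer op write straight into the userbuffer
        fuse_producer_into_userbuffer_copy(graph)
        # 4. drop the finalize when the output feeds another UB allreduce
        remove_redundant_userbuffer_finalize(graph)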

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
2025-04-08 18:39:49 +08:00


import enum
import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.parameter import Parameter

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams

from ...models.modeling_utils import QuantConfig
from ..distributed import ParallelConfig, TensorParallelMode
from ..utils import Fp4QuantizedTensor

E2M1_MAX = 6.0


class WeightMode(str, enum.Enum):
    # weight of a vanilla layer
    VANILLA = 'vanilla'
    # weight of a fused QKV linear layer
    FUSED_QKV_LINEAR = 'fused_qkv_linear'
    # weight of a fused gate and up linear layer
    FUSED_GATE_UP_LINEAR = 'fused_gate_up_linear'


@dataclass(kw_only=True)
class WeightsLoadingConfig:
    weight_mode: WeightMode = WeightMode.VANILLA
    ignore_tensor_parallel: bool = False


def load_weight_shard(
        weight,
        tensor_parallel_size: int = 1,
        tensor_parallel_rank: int = 0,
        tensor_parallel_mode: Optional[TensorParallelMode] = None,
        device: torch.device = torch.device('cpu'),
) -> torch.Tensor:
    """Load the shard of `weight` owned by `tensor_parallel_rank`, split along
    the dimension implied by `tensor_parallel_mode`, and move it to `device`."""
    if isinstance(weight, torch.Tensor):
        tensor_shape = weight.shape

        def maybe_convert_to_torch_tensor(tensor: torch.Tensor,
                                          indices: Optional[slice] = None):
            if indices is None:
                # Avoid unnecessary copy
                return tensor.to(device)
            else:
                return tensor[indices].to(device)
    # WAR to check whether it is a safetensors slice, since safetensors does not
    # register the type to the module.
    # A safetensors slice supports lazy loading; type(weight) is `builtin.PySafeSlice`.
    elif hasattr(weight, "get_shape"):
        tensor_shape = weight.get_shape()

        def maybe_convert_to_torch_tensor(
                tensor,
                indices: Union[slice, tuple[slice, ...]] = slice(None)):
            return tensor[indices].to(device)
    else:
        raise ValueError(f'unsupported weight type: {type(weight)}')

    if tensor_parallel_mode is None or tensor_parallel_size <= 1:
        return maybe_convert_to_torch_tensor(weight)

    split_dim = TensorParallelMode.split_dim(tensor_parallel_mode)
    if len(tensor_shape) == 1 and split_dim == 1:
        return maybe_convert_to_torch_tensor(weight)

    width = tensor_shape[split_dim]
    if width == 1:
        return maybe_convert_to_torch_tensor(weight)

    slice_width = math.ceil(width / tensor_parallel_size)
    slice_start = tensor_parallel_rank * slice_width
    slice_end = min((tensor_parallel_rank + 1) * slice_width, width)
    slice_obj = [slice(None)] * len(tensor_shape)
    slice_obj[split_dim] = slice(slice_start, slice_end)
    return maybe_convert_to_torch_tensor(weight, tuple(slice_obj))
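
# A small worked example of the sharding above (shapes are illustrative): for a
# weight of shape [out_features=8, in_features=4] split column-parallel across
# tensor_parallel_size=2, split_dim is 0 (as implied by the Linear module
# below), slice_width = ceil(8 / 2) = 4, so rank 0 loads rows 0:4 and rank 1
# loads rows 4:8; ROW parallelism applies the same arithmetic to dim 1
# (in_features) instead.

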
def load_weight_scales_fp8_qdq(weights: List[Dict]):
    input_scale, weight_scale = [], []
    for w in weights:
        if "input_scale" in w:
            input_scale.append(w["input_scale"][...].reshape([]))
        if "weight_scale" in w:
            weight_scale.append(w["weight_scale"][...].reshape([]))
    return input_scale, weight_scale


def load_weight_scales_nvfp4(weights: List[Dict],
                             tp_size: int = 1,
                             tp_rank: int = 0,
                             tp_mode: Optional[TensorParallelMode] = None):
    # For concatenated weights (qkv_proj / up_gate_proj), the global scaling
    # factors and input scaling factors should be shared.
    input_scale = None
    weight_scale_2 = None
    weight_scale = []
    for w in weights:
        if "input_scale" in w:
            if input_scale is None:
                input_scale = w["input_scale"][...]
            else:
                assert input_scale == w["input_scale"][
                    ...], "The input_scale should be the same for all the weights"
        if "weight_scale" in w:
            ws = load_weight_shard(w["weight_scale"], tp_size, tp_rank,
                                   tp_mode).contiguous()
            assert ws.dtype == torch.float8_e4m3fn  # TODO: or e8m0 for the mxfp4 recipe?
            weight_scale.append(ws.view(fp4_utils.float4_sf_dtype))
        if "weight_scale_2" in w:
            if weight_scale_2 is None:
                weight_scale_2 = w["weight_scale_2"][...]
            else:
                assert weight_scale_2 == w["weight_scale_2"][
                    ...], "The weight_scale_2 should be the same for all the weights"
    # Compute the scaling factor and alpha required by the GEMM kernels.
    # TODO: ModelOpt's o_proj.weight_scale_2 is bfloat16, which should be float32
    alpha = input_scale.float() * weight_scale_2.float()
    # The modelopt ckpt stores amax / (448 * 6); convert it to (448 * 6) / amax.
    input_scale = 1.0 / input_scale
    return input_scale, weight_scale, alpha
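
# Worked example of the conversion above (illustrative numbers only): if the
# checkpoint stores input_scale = amax_input / (448 * 6) = 0.5 and
# weight_scale_2 = amax_weight / (448 * 6) = 0.25, then
# alpha = 0.5 * 0.25 = (amax_input * amax_weight) / (448 * 6 * 448 * 6),
# and the returned input_scale is 1.0 / 0.5 = 2.0 = (448 * 6) / amax_input.

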
class Linear(nn.Module):
    """Linear layer with optional tensor parallelism and quantization support
    (FP8 qdq, FP8 block scales, and NVFP4)."""

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
                 dtype: Optional[torch.dtype] = None,
                 parallel_config: Optional[ParallelConfig] = None,
                 quant_config: Optional[QuantConfig] = None,
                 weights_loading_config: Optional[WeightsLoadingConfig] = None,
                 is_expert: bool = False,
                 skip_create_weights: bool = False,
                 use_custom_cublas_mm: bool = False):
        from ..distributed import AllReduce

        super().__init__()
        self.has_bias = bias
        self.dtype = dtype
        self.parallel_config = parallel_config or ParallelConfig()
        # could be modified later
        self.quant_config = quant_config
        self.weights_loading_config = (weights_loading_config
                                       or WeightsLoadingConfig())
        self.tp_size = self.parallel_config.tensor_parallel_size
        self.tp_rank = self.parallel_config.tensor_parallel_rank
        self.tp_mode = self.parallel_config.tensor_parallel_mode

        local_in_features = in_features
        local_out_features = out_features
        if self.parallel_config.tensor_parallel_mode == TensorParallelMode.ROW:
            assert in_features % self.tp_size == 0, (
                f'in_features {in_features} must be divisible by tp_size {self.tp_size}'
            )
            local_in_features = in_features // self.tp_size
        elif self.parallel_config.tensor_parallel_mode == TensorParallelMode.COLUMN:
            assert out_features % self.tp_size == 0, (
                f'out_features {out_features} must be divisible by tp_size {self.tp_size}'
            )
            local_out_features = out_features // self.tp_size
        else:
            assert self.parallel_config.tensor_parallel_mode is None, (
                f'unsupported tensor parallel mode: {self.parallel_config.tensor_parallel_mode}'
            )
        self.in_features = local_in_features
        self.out_features = local_out_features
        self.all_reduce = AllReduce(
            self.parallel_config) if not is_expert else None
        self._weights_created = False
        self.is_expert = is_expert
        self.use_custom_cublas_mm = use_custom_cublas_mm
        if not skip_create_weights:
            self.create_weights()

    def create_weights(self):
        if self._weights_created:
            return
        device = torch.device('cuda')
        weight_shape = (self.out_features, self.in_features)

        self.has_any_quant = False
        self.has_fp8_qdq = False
        self.has_fp8_block_scales = False
        self.has_nv_fp4 = False
        # Only create the weight parameters here; quantized weights are loaded
        # directly in load_weights().
        if self.quant_config and self.quant_config.layer_quant_mode.has_any_quant():
            self.has_any_quant = True
            qc = self.quant_config
            if qc.layer_quant_mode.has_fp8_qdq():
                self.has_fp8_qdq = True
                self.weight = Parameter(torch.empty(weight_shape,
                                                    dtype=torch.float8_e4m3fn,
                                                    device=device),
                                        requires_grad=False)
                self.weight_scale = Parameter(torch.tensor(1.,
                                                           dtype=torch.float32,
                                                           device=device),
                                              requires_grad=False)
                self.input_scale = Parameter(torch.tensor(1.,
                                                          dtype=torch.float32,
                                                          device=device),
                                             requires_grad=False)
                self.inv_input_scale = Parameter(torch.tensor(1.,
                                                              dtype=torch.float32,
                                                              device=device),
                                                 requires_grad=False)
            elif qc.layer_quant_mode.has_fp8_block_scales():
                self.has_fp8_block_scales = True
                self.weight = Parameter(torch.empty(weight_shape,
                                                    dtype=torch.float8_e4m3fn,
                                                    device=device),
                                        requires_grad=False)
                scale_shape = (math.ceil(self.out_features / 128),
                               math.ceil(self.in_features / 128))
                self.weight_scale = Parameter(torch.empty(scale_shape,
                                                          dtype=torch.float32,
                                                          device=device),
                                              requires_grad=False)
                # Not really used for the GEMM now.
                # Only used to quantize the output of FP8 attention.
                self.input_scale = Parameter(torch.tensor(1.,
                                                          dtype=torch.float32,
                                                          device=device),
                                             requires_grad=False)
                self.inv_input_scale = Parameter(torch.tensor(1.,
                                                              dtype=torch.float32,
                                                              device=device),
                                                 requires_grad=False)
            elif qc.layer_quant_mode.has_nvfp4():
                self.has_nv_fp4 = True
                self.scaling_vector_size = 16
                assert self.in_features % self.scaling_vector_size == 0, f"in_features {self.in_features} must be divisible by scaling_vector_size {self.scaling_vector_size}"
                # Quantized weights (two FP4 values packed per byte).
                self.weight = Parameter(torch.empty(
                    [self.out_features, self.in_features // 2],
                    dtype=fp4_utils.float4_e2m1x2,
                    device=device),
                                        requires_grad=False)
                # FP8 per-block scaling factors. dtype must be aligned with SF_DTYPE.
                # Padding is required. See computeSFSize in quantization.h.
                nrows = fp4_utils.pad_up(self.out_features, 128)
                ncols = fp4_utils.pad_up(
                    self.in_features // self.scaling_vector_size, 4)
                self.weight_scale = Parameter(torch.empty(
                    [nrows * ncols],
                    dtype=fp4_utils.float4_sf_dtype,
                    device=device),
                                              requires_grad=False)
                # FP32 per-tensor global scaling factor = 448 * 6 / amax_input
                self.input_scale = Parameter(torch.empty([1],
                                                         dtype=torch.float32,
                                                         device=device),
                                             requires_grad=False)
                self.inv_input_scale = Parameter(torch.empty([1],
                                                             dtype=torch.float32,
                                                             device=device),
                                                 requires_grad=False)
                # alpha = (amax_input * amax_weight) / (448 * 6 * 448 * 6)
                self.alpha = Parameter(torch.empty([1],
                                                   dtype=torch.float32,
                                                   device=device),
                                       requires_grad=False)
            else:
                # TODO(zhenhuanc): support other quant mode
                raise ValueError(f'unsupported quant mode: {qc.quant_mode}')
        else:
            self.weight = Parameter(torch.empty(weight_shape,
                                                dtype=self.dtype,
                                                device=device),
                                    requires_grad=False)

        if self.has_bias:
            self.bias = Parameter(torch.empty((self.out_features, ),
                                              dtype=self.dtype,
                                              device=device),
                                  requires_grad=False)
        else:
            self.register_parameter("bias", None)
        self._weights_created = True
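
    # Illustrative sizing for the NVFP4 weights above (numbers are made up):
    # with out_features=64 and in_features=64, `weight` stores 64 x 32 packed
    # FP4 bytes (two values per byte), and `weight_scale` is padded to
    # pad_up(64, 128) * pad_up(64 // 16, 4) = 128 * 4 = 512 entries.
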
    def apply_linear(self, input, weight, bias):
        if self.has_any_quant:
            qc = self.quant_config
            if self.has_fp8_qdq:
                if input.dtype != torch.float8_e4m3fn:
                    qinput, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
                        input, self.input_scale)
                else:
                    qinput = input
                # This op does not support bias now.
                output = torch.ops.trtllm.cublas_scaled_mm(
                    qinput,
                    weight.t(),
                    scale_a=self.input_scale,
                    scale_b=self.weight_scale,
                    bias=None,
                    out_dtype=self.dtype or input.dtype,
                )
                if bias is not None:
                    output = output + bias
            elif self.has_fp8_block_scales:
                if input.dtype == torch.float8_e4m3fn:
                    input = input.to(torch.bfloat16) * self.input_scale
                assert input.dtype == torch.bfloat16

                act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                    input)
                output = torch.ops.trtllm.fp8_block_scaling_gemm(
                    act_input_fp8, self.weight, act_input_sf, self.weight_scale)
            elif self.has_nv_fp4:
                if isinstance(input, Fp4QuantizedTensor):
                    act_fp4, act_sf = input.fp4_tensor, input.scaling_factor
                else:
                    act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
                        input, self.input_scale, self.scaling_vector_size,
                        False)
                output = torch.ops.trtllm.nvfp4_gemm(act_fp4, self.weight,
                                                     act_sf, self.weight_scale,
                                                     self.alpha, False,
                                                     self.dtype)
            else:
                # TODO(zhenhuanc): support other quant mode
                raise ValueError(f'unsupported quant mode: {qc.quant_mode}')
        else:
            # TODO: remove custom cublas_mm when default heuristics is good enough
            if self.use_custom_cublas_mm:
                output = torch.ops.trtllm.cublas_mm(input,
                                                    self.weight.t(),
                                                    bias,
                                                    out_dtype=None)
            else:
                output = F.linear(input, self.weight, bias)
        return output

    def forward(
            self,
            input: Union[torch.Tensor, Fp4QuantizedTensor],
            *,
            all_reduce_params: Optional[AllReduceParams] = None
    ) -> torch.Tensor:
        from ..distributed import allgather
        if self.tp_mode == TensorParallelMode.ROW:
            bias = None if (self.tp_rank > 0) else self.bias
            if not self.is_expert:
                if self.tp_size > 1:
                    fuse_bias_into_all_reduce = (
                        bias is not None and all_reduce_params is not None
                        and (all_reduce_params.fusion_op
                             == AllReduceFusionOp.RESIDUAL_RMS_NORM))
                    if fuse_bias_into_all_reduce:
                        all_reduce_params.bias = bias
                        bias = None
                else:
                    assert all_reduce_params is None or all_reduce_params.enable_allreduce is False, "Cannot fuse norm/residual/bias ops into allreduce op since we do not call allreduce op when tp_size is 1."
                output = self.apply_linear(input, self.weight, bias)
                output = self.all_reduce(
                    output,
                    all_reduce_params=all_reduce_params,
                )
            else:
                output = self.apply_linear(input, self.weight, bias)
        elif self.tp_mode == TensorParallelMode.COLUMN:
            output = self.apply_linear(input, self.weight, self.bias)
            if self.parallel_config.gather_output:
                output = allgather(output, self.parallel_config)
        else:
            output = self.apply_linear(input, self.weight, self.bias)
        return output
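
    # Example usage (a minimal sketch, not taken from this repo): a ROW-parallel
    # output projection whose allreduce fuses residual-add + RMSNorm. Here
    # `hidden_size`, `x`, and `ar_params` are placeholders; `ar_params` stands
    # for an AllReduceParams built by the caller with
    # fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, and ParallelConfig is
    # assumed to accept its fields as constructor keyword arguments.
    #
    #   o_proj = Linear(hidden_size, hidden_size, bias=False,
    #                   dtype=torch.bfloat16,
    #                   parallel_config=ParallelConfig(
    #                       tensor_parallel_mode=TensorParallelMode.ROW))
    #   out = o_proj(x, all_reduce_params=ar_params)
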
    def load_weights(self, weights: List[Dict]):
        """Load checkpoint weights (one dict per constituent sub-module) into
        this layer, sharding them according to the tensor parallel config."""
        assert self._weights_created

        def copy(dst: Parameter, src: torch.Tensor):
            assert dst.dtype == src.dtype, f"Incompatible dtype. dst: {dst.dtype}, src: {src.dtype}"
            dst.data.copy_(src)

        weight_mode = self.weights_loading_config.weight_mode
        quant_mode = self.quant_config.quant_mode if self.quant_config else None
        # Load weight shards onto the GPU to speed up operations on the shards.
        device = torch.device('cuda')

        if weight_mode == WeightMode.VANILLA:
            assert len(weights) == 1

            weight = load_weight_shard(weights[0]['weight'], self.tp_size,
                                       self.tp_rank, self.tp_mode, device)
            copy(self.weight, weight)

            if self.bias is not None:
                bias = load_weight_shard(weights[0]['bias'], self.tp_size,
                                         self.tp_rank, self.tp_mode, device)
                copy(self.bias, bias)

            if quant_mode:
                if quant_mode.has_fp8_qdq():
                    input_scale, weight_scale = load_weight_scales_fp8_qdq(
                        weights)
                    copy(self.input_scale, input_scale[0])
                    copy(self.weight_scale, weight_scale[0])
                    self.inv_input_scale.data = 1.0 / self.input_scale
                elif quant_mode.has_nvfp4():
                    input_scale, weight_scale, alpha = load_weight_scales_nvfp4(
                        weights,
                        tp_size=self.tp_size,
                        tp_rank=self.tp_rank,
                        tp_mode=self.tp_mode)
                    assert len(weights) == 1
                    weight_scale = weight_scale[0]
                    # Swizzle the weight scale into the kernel's expected layout.
                    weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
                        weight_scale)
                    copy(self.input_scale, input_scale)
                    copy(self.weight_scale, weight_scale)
                    self.inv_input_scale.data = self.input_scale / E2M1_MAX
                    copy(self.alpha, alpha)
                elif quant_mode.has_fp8_block_scales():
                    # `weight_scale_inv` for the DS recipe and `weight_scale` for the ModelOpt recipe.
                    # They actually hold identical values of data_amax / 448.
                    scale_name = "weight_scale_inv"
                    if scale_name not in weights[0]:
                        scale_name = "weight_scale"
                    weight_scale = load_weight_shard(weights[0][scale_name],
                                                     self.tp_size, self.tp_rank,
                                                     self.tp_mode, device)
                    copy(self.weight_scale, weight_scale)
                    if "input_scale" in weights[0]:
                        copy(self.input_scale, weights[0]["input_scale"])
                        self.inv_input_scale.data = 1.0 / self.input_scale
        elif weight_mode == WeightMode.FUSED_QKV_LINEAR:
            assert len(weights) == 3

            q_weight = load_weight_shard(weights[0]['weight'], self.tp_size,
                                         self.tp_rank, self.tp_mode, device)
            k_weight = load_weight_shard(weights[1]['weight'], self.tp_size,
                                         self.tp_rank, self.tp_mode, device)
            v_weight = load_weight_shard(weights[2]['weight'], self.tp_size,
                                         self.tp_rank, self.tp_mode, device)

            if quant_mode:
                if quant_mode.has_fp8_qdq():
                    input_scale, weight_scale = load_weight_scales_fp8_qdq(
                        weights)
                    copy(self.input_scale, max(input_scale))
                    copy(self.weight_scale, max(weight_scale))
                    q_weight = q_weight.to(self.dtype) * weight_scale[0]
                    k_weight = k_weight.to(self.dtype) * weight_scale[1]
                    v_weight = v_weight.to(self.dtype) * weight_scale[2]
                elif quant_mode.has_nvfp4():
                    input_scale, weight_scale, alpha = load_weight_scales_nvfp4(
                        weights,
                        tp_size=self.tp_size,
                        tp_rank=self.tp_rank,
                        tp_mode=self.tp_mode)
                    # Swizzle weight scales after concatenation
                    weight_scale = torch.cat(weight_scale, 0)
                    weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
                        weight_scale)
                    copy(self.input_scale, input_scale)
                    copy(self.weight_scale, weight_scale)
                    copy(self.alpha, alpha)
                elif quant_mode.has_fp8_block_scales():
                    scale_name = "weight_scale_inv"
                    if scale_name not in weights[0]:
                        scale_name = "weight_scale"
                    q_scale = load_weight_shard(weights[0][scale_name],
                                                self.tp_size, self.tp_rank,
                                                self.tp_mode).contiguous()
                    k_scale = load_weight_shard(weights[1][scale_name],
                                                self.tp_size, self.tp_rank,
                                                self.tp_mode).contiguous()
                    v_scale = load_weight_shard(weights[2][scale_name],
                                                self.tp_size, self.tp_rank,
                                                self.tp_mode).contiguous()
                    fused_fp8_block_scale = torch.cat(
                        (q_scale, k_scale, v_scale))
                    copy(self.weight_scale, fused_fp8_block_scale)

            fused_weight = torch.cat((q_weight, k_weight, v_weight))
            if quant_mode and quant_mode.has_fp8_qdq():
                fused_weight = (fused_weight / self.weight_scale).to(
                    torch.float8_e4m3fn)
            copy(self.weight, fused_weight)

            if self.bias is not None:
                q_bias = load_weight_shard(weights[0]['bias'], self.tp_size,
                                           self.tp_rank, self.tp_mode, device)
                k_bias = load_weight_shard(weights[1]['bias'], self.tp_size,
                                           self.tp_rank, self.tp_mode, device)
                v_bias = load_weight_shard(weights[2]['bias'], self.tp_size,
                                           self.tp_rank, self.tp_mode, device)
                copy(self.bias, torch.cat((q_bias, k_bias, v_bias)))
        elif weight_mode == WeightMode.FUSED_GATE_UP_LINEAR:
            assert len(weights) == 2

            gate_weight = load_weight_shard(weights[0]['weight'], self.tp_size,
                                            self.tp_rank, self.tp_mode, device)
            up_weight = load_weight_shard(weights[1]['weight'], self.tp_size,
                                          self.tp_rank, self.tp_mode, device)

            if quant_mode:
                if quant_mode.has_fp8_qdq():
                    input_scale, weight_scale = load_weight_scales_fp8_qdq(
                        weights)
                    copy(self.input_scale, max(input_scale))
                    copy(self.weight_scale, max(weight_scale))
                    gate_weight = gate_weight.to(self.dtype) * weight_scale[0]
                    up_weight = up_weight.to(self.dtype) * weight_scale[1]
                elif quant_mode.has_nvfp4():
                    input_scale, weight_scale, alpha = load_weight_scales_nvfp4(
                        weights,
                        tp_size=self.tp_size,
                        tp_rank=self.tp_rank,
                        tp_mode=self.tp_mode)
                    # Swizzle weight scales after concatenation
                    weight_scale = torch.cat(weight_scale, 0)
                    weight_scale = torch.ops.tensorrt_llm.nvfp4_block_scale_interleave(
                        weight_scale)
                    copy(self.input_scale, input_scale)
                    copy(self.weight_scale, weight_scale)
                    copy(self.alpha, alpha)
                elif quant_mode.has_fp8_block_scales():
                    scale_name = "weight_scale_inv"
                    if scale_name not in weights[0]:
                        scale_name = "weight_scale"
                    left_scale = load_weight_shard(weights[0][scale_name],
                                                   self.tp_size, self.tp_rank,
                                                   self.tp_mode, device)
                    right_scale = load_weight_shard(weights[1][scale_name],
                                                    self.tp_size, self.tp_rank,
                                                    self.tp_mode, device)
                    fused_scale = torch.cat([left_scale, right_scale], dim=0)
                    copy(self.weight_scale, fused_scale)

            fused_weight = torch.cat((gate_weight, up_weight))
            if quant_mode and quant_mode.has_fp8_qdq():
                fused_weight = (fused_weight / self.weight_scale).to(
                    torch.float8_e4m3fn)
            copy(self.weight, fused_weight)

            if self.bias is not None:
                gate_bias = load_weight_shard(weights[0]['bias'], self.tp_size,
                                              self.tp_rank, self.tp_mode,
                                              device)
                up_bias = load_weight_shard(weights[1]['bias'], self.tp_size,
                                            self.tp_rank, self.tp_mode, device)
                # Concatenate in the same (gate, up) order as the fused weight.
                copy(self.bias, torch.cat((gate_bias, up_bias)))
        else:
            raise ValueError(f'unsupported weight mode: {weight_mode}')
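
# For reference, a sketch of the `weights` argument expected by load_weights()
# (names and shapes are illustrative, not taken from a real checkpoint): for
# WeightMode.FUSED_QKV_LINEAR it is a list of three dicts, one each for q/k/v,
# e.g. [{'weight': q_w, 'bias': q_b, 'input_scale': s_q, 'weight_scale': ws_q},
# {...k...}, {...v...}], while VANILLA expects a single-element list and
# FUSED_GATE_UP_LINEAR a two-element list (gate first, then up).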