import fnmatch
from typing import Union

import torch

from .._utils import get_init_params
from ..layers import (MLP, Attention, ColumnLinear, Embedding, GatedMLP,
                      LayerNorm, RmsNorm, RowLinear)
from ..layers.moe import MixtureOfExperts
from ..models.modeling_utils import LayerQuantConfig, QuantConfig
from ..parameter import Parameter
from .layers import (FP4Linear, FP4RowLinear, FP8Linear, FP8RowLinear,
                     Fp8RowwiseAttention, Fp8RowwiseGatedMLP, Fp8RowwiseMLP,
                     Fp8RowwiseRmsNorm, Int8SmoothQuantLinear,
                     Int8SmoothQuantRowLinear, QServeAttention, QServeGatedMLP,
                     QServeMLP, QServeRmsNorm, SmoothQuantAttention,
                     SmoothQuantGatedMLP, SmoothQuantLayerNorm, SmoothQuantMLP,
                     SmoothQuantRmsNorm, WeightOnlyGroupwiseQuantColumnLinear,
                     WeightOnlyGroupwiseQuantRowLinear,
                     WeightOnlyQuantColumnLinear, WeightOnlyQuantEmbedding,
                     WeightOnlyQuantRowLinear)
from .mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo, QuantMode


def quantize_layers(
    model,
    quant_config: QuantConfig,
    quant_map,
    preprocess_init_params=None,
):
    exclude_modules = quant_config.exclude_modules or [
        '*lm_head',
        '*router',
        '*vocab_embedding',
        '*position_embedding',
        '*block_embedding',
        '*shared_expert_gate',
    ]

    for name, module, parent in model.named_modules_with_parent():
        module_name = name.rsplit('.', 1)[-1]
        is_excluded = False
        quant_cls = None

        # Check whether this module matches an exclusion pattern.
        for exclude_module in exclude_modules:
            if fnmatch.fnmatchcase(name, exclude_module):
                is_excluded = True
                break

        # MoE modules are quantized in their constructor, so they must always
        # be re-created with the appropriate quant_mode; when excluded,
        # re-create them with quant_mode 0. This special case may go away if
        # the MoE implementation is redesigned.
        if isinstance(module, MixtureOfExperts):
            quant_cls = type(module)
        elif not is_excluded:
            for cls in quant_map:
                if isinstance(module, cls):
                    quant_cls = quant_map[cls]
                    break

        if quant_cls:
            init_params = get_init_params(module, quant_cls)

            if isinstance(module, MixtureOfExperts):
                if is_excluded:
                    quant_mode = QuantMode(0)
                else:
                    quant_mode = quant_config.quant_mode
                init_params["quant_mode"] = quant_mode

            if "bias" in init_params and not isinstance(
                    module, MixtureOfExperts):
                init_params["bias"] = init_params["bias"] is not None
            if isinstance(module, ColumnLinear):
                init_params[
                    "out_features"] = module.out_features * module.tp_size
            elif isinstance(module, RowLinear):
                init_params["in_features"] = module.in_features * module.tp_size
            if preprocess_init_params is not None:
                preprocess_init_params(init_params, name, module)
            quant_layer = quant_cls(**init_params)
            if parent is not None:
                setattr(parent, module_name, quant_layer)
            else:
                model = quant_layer

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model


def weight_only_quantize(model, quant_config: QuantConfig, model_config=None):
    assert quant_config.quant_mode.is_weight_only()

    try:
        model_cfg = model.config
    except AttributeError:
        model_cfg = model_config

    quant_map = {
        ColumnLinear: WeightOnlyQuantColumnLinear,
        RowLinear: WeightOnlyQuantRowLinear,
        Embedding: WeightOnlyQuantEmbedding,
    }

    def preprocess_init_params(init_params, name, module):
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(module, ColumnLinear):
            module_name = name.rsplit('.', 1)[-1]
            init_params["transb"] = module_name == "lm_head"
        if "tp_rank" in init_params:
            init_params["tp_rank"] = model_cfg.mapping.tp_rank

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
        preprocess_init_params,
    )
    return model
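

# Illustrative usage sketch (assumption, not an exported helper): weight-only
# quantization is normally driven from a QuantConfig; `quant_mode` is derived
# from `quant_algo`, and `weight_only_quantize` then swaps Linear/Embedding
# layers according to the `quant_map` above. The enum member below is an
# assumed name, not something this module defines.
#
#   quant_config = QuantConfig(quant_algo=QuantAlgo.W8A16)
#   assert quant_config.quant_mode.is_weight_only()
#   model = weight_only_quantize(model, quant_config, model.config)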


def weight_only_groupwise_quantize(model,
                                   quant_config: QuantConfig,
                                   model_config=None):
    assert quant_config.quant_mode.is_weight_only()

    try:
        model_cfg = model.config
    except AttributeError:
        model_cfg = model_config

    quant_map = {
        ColumnLinear: WeightOnlyGroupwiseQuantColumnLinear,
        RowLinear: WeightOnlyGroupwiseQuantRowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }

    def preprocess_init_params(init_params, name, module):
        init_params["group_size"] = quant_config.group_size
        init_params["pre_quant_scale"] = quant_config.pre_quant_scale
        init_params["zero"] = quant_config.has_zero_point
        init_params[
            "use_w4a8_awq"] = quant_config.quant_algo == QuantAlgo.W4A8_AWQ
        init_params[
            "use_int8_weight"] = quant_config.quant_algo == QuantAlgo.W8A16_GPTQ
        if "tp_rank" in init_params:
            init_params["tp_rank"] = model_cfg.mapping.tp_rank

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
        preprocess_init_params,
    )
    return model


def smooth_quantize_ootb(
    model,
    quant_config: QuantConfig,
):
    quant_map = {
        ColumnLinear: Int8SmoothQuantLinear,
        RowLinear: Int8SmoothQuantRowLinear,
    }

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model


def smooth_quantize_plugin(model, quant_mode):
    quant_map = {
        RmsNorm: SmoothQuantRmsNorm,
        LayerNorm: SmoothQuantLayerNorm,
        GatedMLP: SmoothQuantGatedMLP,
        MLP: SmoothQuantMLP,
        Attention: SmoothQuantAttention,
    }
    for name, layer, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name in ['ln_f', 'ln_embed']:
            continue

        quant_cls = None
        for cls in quant_map:
            if isinstance(layer, cls):
                quant_cls = quant_map[cls]
                break

        if quant_cls is None:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_mode
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params)
        if parent is not None:
            setattr(parent, layer_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_mode)
    return model


def smooth_quantize(model, quant_config: QuantConfig):
    assert quant_config.quant_mode.has_act_and_weight_quant()
    if quant_config.quant_algo in W8A8_SQ_PLUGIN_LIST:
        return smooth_quantize_plugin(model, quant_config.quant_mode)
    else:
        return smooth_quantize_ootb(model, quant_config)
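

# Illustrative dispatch sketch (assumption): SmoothQuant algorithms listed in
# W8A8_SQ_PLUGIN_LIST replace whole RmsNorm/LayerNorm/MLP/Attention blocks via
# the plugin path; other W8A8 variants only swap the Linear layers ("OOTB"
# path using Int8SmoothQuant*Linear). The enum member below is an assumed name.
#
#   quant_config = QuantConfig(quant_algo=QuantAlgo.W8A8_SQ_PER_CHANNEL)
#   assert quant_config.quant_mode.has_act_and_weight_quant()
#   model = smooth_quantize(model, quant_config)  # picks plugin vs. OOTB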


def fp8_quantize(model, quant_config: QuantConfig):
    assert quant_config.quant_mode.has_fp8_qdq()

    quant_map = {
        ColumnLinear: FP8Linear,
        RowLinear: FP8RowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model


def fp8_rowwise_quantize(model, quant_config: QuantConfig):
    assert quant_config.quant_mode.has_fp8_rowwise()

    quant_cls_map = {
        RmsNorm: Fp8RowwiseRmsNorm,
        GatedMLP: Fp8RowwiseGatedMLP,
        MLP: Fp8RowwiseMLP,
        Attention: Fp8RowwiseAttention,
    }

    if quant_config.exclude_modules is None:
        exclude_modules = ['*ln_f', '*ln_embed']
    else:
        exclude_modules = quant_config.exclude_modules

    def extract_layer_idx(name):
        ss = name.split('.')
        for s in ss:
            if s.isdigit():
                return int(s)
        return None

    # Meta's LLaMA 3.1 recipe:
    # (1) Skip quantization for the first and last Transformer layers.
    # (2) Skip quantization for the Attention layers.
    if quant_config.use_meta_recipe:
        exclude_modules.extend(['*input_layernorm', '*attention'])

    for name, layer, parent in model.named_modules_with_parent():
        module_name = name.rsplit('.', 1)[-1]

        if quant_config.use_meta_recipe:
            local_layer_idx = extract_layer_idx(name)
            mapping = model.config.mapping
            layers_range = mapping.pp_layers(model.config.num_hidden_layers)
            if mapping.is_first_pp_rank() and local_layer_idx == 0:
                continue
            if mapping.is_last_pp_rank(
            ) and local_layer_idx == len(layers_range) - 1:
                continue

        quant_cls = None
        for cls in quant_cls_map:
            if isinstance(layer, cls):
                quant_cls = quant_cls_map[cls]
                break
        if quant_cls is None:
            continue

        is_excluded = False
        for exclude_module in exclude_modules:
            if fnmatch.fnmatchcase(name, exclude_module):
                is_excluded = True
                break
        if is_excluded:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params,
                                clamp_val=quant_config.clamp_val)
        if parent is not None:
            setattr(parent, module_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model


# TODO: These functions should be moved to ModelOpt.


def qserve_quantize_weight_per_group(linear_weight: torch.HalfTensor,
                                     s1_scales: torch.FloatTensor,
                                     s2_scales: torch.FloatTensor,
                                     s2_szeros: torch.FloatTensor,
                                     group_size: int) -> torch.CharTensor:
    out_features = linear_weight.shape[0]
    in_features = linear_weight.shape[1]

    # Step 1: Quantize the weights to int8.
    linear_weight = linear_weight.div(
        s1_scales.reshape(out_features, 1).to(linear_weight.device))
    linear_weight = linear_weight.round()
    # assert linear_weight.min() >= -119 and linear_weight.max() <= 119, \
    #     "Stage 1: Quantized weight out of range"  # 119 is the "magic" number
    assert (linear_weight.min() >= -128 and linear_weight.max() <= 127
            ), "Stage 1: Quantized weight out of range"

    # Step 2: Quantize the weights to int4.
    linear_weight = linear_weight.reshape(out_features,
                                          in_features // group_size,
                                          group_size)
    s2_szeros = s2_szeros.reshape(out_features, in_features // group_size,
                                  1).to(torch.float16).to(linear_weight.device)
    s2_scales = s2_scales.reshape(out_features, in_features // group_size,
                                  1).to(torch.float16).to(linear_weight.device)
    linear_weight = linear_weight.add(s2_szeros).div(s2_scales).round()
    assert (linear_weight.min() >= 0 and linear_weight.max() <= 15
            ), "Stage 2: Quantized weight out of range"

    qweight = linear_weight.reshape(out_features, in_features).to(torch.int8)
    return qweight


def qserve_quantize_weight_per_channel(
        linear_weight: torch.HalfTensor, s1_scales: torch.FloatTensor,
        s1_szeros: torch.FloatTensor) -> torch.CharTensor:
    out_features = linear_weight.shape[0]
    in_features = linear_weight.shape[1]

    # Step 1: Quantize the weights to int4.
    s1_scales = s1_scales.reshape(out_features, 1).to(linear_weight.device)
    s1_szeros = s1_szeros.reshape(out_features, 1).to(linear_weight.device)
    qweight = linear_weight.add(s1_szeros).div(s1_scales).round()
    assert (qweight.min() >= 0
            and qweight.max() <= 15), "Quantized weight out of range"

    return qweight.reshape(out_features, in_features).to(torch.int8)
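

# Minimal self-contained sketch of the per-group QServe weight quantization
# above. The scale/zero values are fabricated so that both range assertions
# hold; in the real flow they come from the calibration step. This helper is
# illustrative only and is not called anywhere in TensorRT-LLM.
def _example_qserve_per_group_weight_quantization():
    out_features, in_features, group_size = 8, 256, 128
    # Start from the int4 codes we expect to recover, in [0, 15].
    q = torch.randint(0, 16,
                      (out_features, in_features // group_size, group_size))
    s2_scales = torch.full((out_features, in_features // group_size), 2.0)
    s2_szeros = torch.full((out_features, in_features // group_size), 16.0)
    s1_scales = torch.full((out_features, ), 0.01)
    # Rebuild a fake fp16 weight: w = (q * s2_scale - s2_szero) * s1_scale.
    w_int8 = q.float() * 2.0 - 16.0
    weight = (w_int8 * 0.01).reshape(out_features, in_features).half()
    qweight = qserve_quantize_weight_per_group(weight, s1_scales, s2_scales,
                                               s2_szeros, group_size)
    assert torch.equal(qweight,
                       q.reshape(out_features, in_features).to(torch.int8))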


# Pack the quantized weights, scales and zeros, and apply the reordering
# required by the QServe kernels.
# Return: processed [qweight, s1_scales, s2_scales, s2_zeros]
def qserve_pack_reorder_per_group(qweight: torch.CharTensor,
                                  s1_scales: torch.FloatTensor,
                                  s2_scales: torch.FloatTensor,
                                  s2_szeros: torch.FloatTensor, group_size):
    out_features = qweight.shape[0]
    in_features = qweight.shape[1]

    outputs = []

    s1_scales = s1_scales.reshape(out_features).to(torch.float16)
    s2_szeros = s2_szeros.reshape(out_features,
                                  in_features // group_size).to(torch.int8)
    s2_scales = s2_scales.reshape(out_features,
                                  in_features // group_size).to(torch.int8)

    # Step 3: Pack the quantized weights to real quantized weights.
    # ---- Repack the weight ---- #
    assert qweight.dtype == torch.int8
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    W_unpack_reorder = (qweight.reshape(
        out_features // 32,
        2,
        2,
        8,
        in_features // 32,
        2,
        4,
        4,
    ).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
    W_unpack_reorder = (W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7,
                                                 4).contiguous().to(torch.int8))
    # B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
    # [16, 0, 17, 1, ...]
    W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[..., 0]
    W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
                                                  in_features // 32, 32, 16)
    W_unpack_repacked = W_unpack_repacked.reshape(out_features,
                                                  in_features // 2)
    outputs.append(W_unpack_repacked)

    # For the last dimension, organize as 0, 8, 16, 24, 1, 9, 17, 25, ...,
    # following the requirement of the tensor core GEMM.
    # ---- Pack the scales ---- #
    outputs.append(s1_scales.reshape(out_features))

    s2_scales = (s2_scales.reshape(out_features,
                                   in_features // group_size).transpose(
                                       0, 1).contiguous())
    s2_scales = s2_scales.reshape(in_features // group_size,
                                  out_features // 32, 32)
    s2_scales = (s2_scales.reshape(in_features // group_size,
                                   out_features // 32, 4,
                                   8).transpose(-2, -1).contiguous())
    s2_scales = s2_scales.reshape(in_features // group_size,
                                  out_features).contiguous()
    outputs.append(s2_scales)

    # ---- Pack the zeros ---- #
    s2_szeros = (s2_szeros.reshape(out_features,
                                   in_features // group_size).transpose(
                                       0, 1).contiguous())
    s2_szeros = s2_szeros.reshape(in_features // group_size,
                                  out_features // 32, 32)
    s2_szeros = (s2_szeros.reshape(in_features // group_size,
                                   out_features // 32, 4,
                                   8).transpose(-2, -1).contiguous())
    s2_szeros = (s2_szeros.reshape(in_features // group_size,
                                   out_features).contiguous())

    # (q - s2_zeros) * s2_scales = q * s2_scales - s2_zeros * s2_scales,
    # so we convert s2_zeros -> -s2_zeros * s2_scales.
    s2_szeros = (-s2_szeros).int()  # It has been pre-scaled in DeepCompressor.
    s2_szeros = s2_szeros.to(torch.int8)
    outputs.append(s2_szeros)

    return outputs
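

# Minimal self-contained sketch of the shapes produced by
# `qserve_pack_reorder_per_group` for a toy layer: both dimensions must be
# multiples of 32, and the K dimension a multiple of the group size. This
# helper is illustrative only and is not called anywhere in TensorRT-LLM.
def _example_qserve_per_group_packing():
    out_features, in_features, group_size = 64, 256, 128
    qweight = torch.randint(0, 16, (out_features, in_features),
                            dtype=torch.int8)
    s1_scales = torch.rand(out_features)
    s2_scales = torch.randint(
        1, 8, (out_features, in_features // group_size)).float()
    s2_szeros = torch.randint(
        0, 8, (out_features, in_features // group_size)).float()

    packed_weight, s1, s2, zeros = qserve_pack_reorder_per_group(
        qweight, s1_scales, s2_scales, s2_szeros, group_size)
    assert packed_weight.shape == (out_features, in_features // 2)  # 2 int4 per byte
    assert s1.shape == (out_features, )
    assert s2.shape == (in_features // group_size, out_features)
    assert zeros.shape == (in_features // group_size, out_features)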


def qserve_pack_reorder_per_channel(qweight: torch.CharTensor,
                                    s1_scales: torch.FloatTensor,
                                    s1_szeros: torch.FloatTensor):
    out_features = qweight.shape[0]
    in_features = qweight.shape[1]

    outputs = []

    # ---- Repack the weight ---- #
    assert qweight.dtype == torch.int8
    # pack to M // 32, K // 32, (8, 4), ([2], 2, 2, 4)
    W_unpack_reorder = (qweight.reshape(
        out_features // 32,
        2,
        2,
        8,
        in_features // 32,
        2,
        4,
        4,
    ).permute(0, 4, 3, 6, 1, 5, 2, 7).contiguous())
    W_unpack_reorder = (W_unpack_reorder.permute(0, 1, 2, 3, 5, 6, 7,
                                                 4).contiguous())
    # B_fp16_reorder = B_fp16_reorder[:, :, :, :, :, :, [3, 2, 1, 0]].contiguous()
    # [16, 0, 17, 1, ...]
    W_unpack_repacked = (W_unpack_reorder[..., 1] << 4) + W_unpack_reorder[..., 0]
    W_unpack_repacked = W_unpack_repacked.reshape(out_features // 32,
                                                  in_features // 32, 32, 16)
    W_unpack_repacked = W_unpack_repacked.reshape(out_features,
                                                  in_features // 2).contiguous()
    outputs.append(W_unpack_repacked)

    # ---- Pack the scales and zeros ---- #
    s1_scales = s1_scales.reshape(out_features).contiguous()
    outputs.append(s1_scales.half())

    s1_szeros = s1_szeros.reshape(out_features).contiguous().half()
    outputs.append(s1_szeros)

    return outputs


# TODO: Duplicates smooth_quantize and quantize_layers.
def qserve_quantize(model, quant_config: QuantConfig):
    quant_mode = quant_config.quant_mode
    assert quant_config.quant_mode.is_qserve_w4a8()

    quant_map = {
        RmsNorm: QServeRmsNorm,
        LayerNorm: QServeRmsNorm,
        GatedMLP: QServeGatedMLP,
        MLP: QServeMLP,
        Attention: QServeAttention,
    }
    for name, layer, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name in ['ln_f', 'ln_embed']:
            continue

        quant_cls = None
        for cls in quant_map:
            if isinstance(layer, cls):
                quant_cls = quant_map[cls]
                break

        if quant_cls is None:
            continue

        init_params = get_init_params(layer, quant_cls)
        init_params["quant_mode"] = quant_mode
        if isinstance(layer, Attention):
            init_params[
                "num_attention_heads"] = layer.num_attention_heads * layer.tp_size
        quant_layer = quant_cls(**init_params)
        if parent is not None:
            setattr(parent, layer_name, quant_layer)
        else:
            model = quant_layer

    setattr(model, 'quant_mode', quant_mode)
    return model


def fp4_quantize(model, quant_config: QuantConfig):
    assert quant_config.quant_mode.has_nvfp4()

    quant_map = {
        ColumnLinear: FP4Linear,
        RowLinear: FP4RowLinear,
        MixtureOfExperts: MixtureOfExperts,
    }

    model = quantize_layers(
        model,
        quant_config,
        quant_map,
    )
    return model


# For now, assume the KV cache is quantized for all layers.
def kv_cache_quantize(model):
    for name, module in model.named_modules():
        if isinstance(module,
                      (Attention, SmoothQuantAttention, Fp8RowwiseAttention)):
            # for dequant
            module.kv_cache_scaling_factor = Parameter(shape=(1, ),
                                                       dtype='float32')
            # for quant
            module.kv_cache_rcp_scaling_factor = Parameter(shape=(1, ),
                                                           dtype='float32')
    return model
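

# Illustrative sketch (assumption): after `kv_cache_quantize`, a checkpoint
# loader is expected to fill the two scales of every attention layer, e.g.
#
#   attn.kv_cache_scaling_factor.value = np.array([scale], dtype=np.float32)
#   attn.kv_cache_rcp_scaling_factor.value = np.array([1.0 / scale],
#                                                     dtype=np.float32)
#
# Only the parameter names come from this file; `Parameter.value` and the
# reciprocal-scale convention are assumptions here.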


def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]):
    for name, module, parent in model.named_modules_with_parent():
        if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION:
            layer_quant_mode = quant_config.layer_quant_mode(name)
        else:
            layer_quant_mode = quant_config.layer_quant_mode
        if layer_quant_mode == QuantMode(0):
            continue

        layer_quant_cfg = quant_config._get_quant_cfg(name)

        if layer_quant_mode.has_fp8_qdq():
            module = fp8_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_fp8_rowwise():
            module = fp8_rowwise_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_qserve_w4a8():
            module = qserve_quantize(module, quant_config)
        elif layer_quant_mode.has_nvfp4():
            module = fp4_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.has_act_and_weight_quant():
            module = smooth_quantize(module, layer_quant_cfg)
        elif layer_quant_mode.is_weight_only():
            if layer_quant_mode.has_per_group_scaling():
                module = weight_only_groupwise_quantize(
                    module, layer_quant_cfg, model.config)
            else:
                module = weight_only_quantize(module, layer_quant_cfg,
                                              model.config)

        if parent is not None:  # quantize a single layer
            module_name = name.rsplit('.', 1)[-1]
            setattr(parent, module_name, module)
        else:  # quantize the whole model
            model = module
            break

    if quant_config.quant_mode.has_kv_cache_quant():
        model = kv_cache_quantize(model)

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model
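

# Illustrative end-to-end sketch (assumption): `quantize` is the single entry
# point; it dispatches per layer for MIXED_PRECISION or globally otherwise,
# then adds KV-cache scales if requested. `kv_cache_quant_algo` and the FP8
# enum member are assumed names.
#
#   quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
#                              kv_cache_quant_algo=QuantAlgo.FP8)
#   model = quantize(model, quant_config)
#   assert model.quant_mode.has_fp8_qdq()
#   assert model.quant_mode.has_kv_cache_quant()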