From 7910d4d2a9b9cee1bc8c3b7d42bd0a050c555e2b Mon Sep 17 00:00:00 2001 From: Frida Hou <201670829+Fridah-nv@users.noreply.github.com> Date: Fri, 30 Jan 2026 23:07:24 -0800 Subject: [PATCH] [#8242][feat] Add int4 GPTQ support for AutoDeploy (#8248) Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- .../auto_deploy/model_registry/models.yaml | 2 + .../_torch/auto_deploy/config/default.yaml | 2 + .../auto_deploy/custom_ops/torch_quant.py | 95 ++++++++++ .../auto_deploy/models/quant_config_reader.py | 25 ++- .../transform/library/quantization.py | 107 ++++++++++- .../unit/singlegpu/custom_ops/test_gptq_op.py | 178 ++++++++++++++++++ .../library/test_quantization.py | 114 ++++++++++- 7 files changed, 516 insertions(+), 7 deletions(-) create mode 100644 tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_gptq_op.py diff --git a/examples/auto_deploy/model_registry/models.yaml b/examples/auto_deploy/model_registry/models.yaml index 61dfed1ab1..879a5eb0c6 100644 --- a/examples/auto_deploy/model_registry/models.yaml +++ b/examples/auto_deploy/model_registry/models.yaml @@ -5,6 +5,8 @@ models: yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml'] - name: Qwen/Qwen2.5-0.5B-Instruct yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml'] +- name: Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4 + yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml'] - name: Qwen/Qwen3-0.6B yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml'] # DISABLED: TorchDynamo compilation error - fake tensor dispatch failure diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index ff5a32eac8..948e2f8139 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -63,6 +63,8 @@ transforms: stage: pattern_matcher quantize_int4_linear_from_config: stage: pattern_matcher + quantize_int4_gptq_linear_from_config: + stage: pattern_matcher quantize_fp8_linear_from_config: stage: pattern_matcher quantize_nvfp4_linear_from_config: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py index 3fec7800d4..78a1fe5d83 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py @@ -319,3 +319,98 @@ def _fake( N_half = weight_quantized.shape[-2] N = N_half * 2 return torch.empty((*input.shape[:-1], N), dtype=input.dtype, device=input.device) + + +@torch.library.custom_op("auto_deploy::torch_fake_quant_int4_gptq_linear", mutates_args=()) +def torch_fake_quant_int4_gptq_linear( + input: torch.Tensor, # [..., K] + weight_quantized: torch.Tensor, # qweight [K/8, N] int32 (packed) + bias: Optional[torch.Tensor], # [N] or None + input_scale: List[torch.Tensor], # unused for GPTQ + weight_scale: List[torch.Tensor], # GPTQ scales [G, N] + input_zp: List[torch.Tensor], # unused for GPTQ + weight_zp: List[torch.Tensor], # GPTQ qzeros [G, N/8] int32 +) -> torch.Tensor: + """ + GPTQ INT4 linear with compatible signature to other quant ops. 
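+
+    Dequantization (g = k // group_size):
+        W[k, n] = scales[g, n] * (unpack_int4(qweight)[k, n] - unpack_int4(qzeros)[g, n])
+        out = input @ W (+ bias)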
+ - weight_quantized: qweight [K/8, N] packed int32 + - weight_scale[0]: scales [G, N] + - weight_zp[0]: qzeros [G, N/8] packed int32 + """ + PACK_FACTOR = 8 + MAXQ = 15 + dequant_dtype = torch.int8 + + qweight = weight_quantized + scales = _expect_single_scale(weight_scale, "weight_scale") + qzeros = _expect_single_scale(weight_zp, "weight_zp") + + dev = qweight.device + input_shape = input.shape + in_features = input_shape[-1] + + if qweight.dim() != 2: + raise RuntimeError("qweight must be 2D [K/8, N]") + K = qweight.size(0) * PACK_FACTOR + N = qweight.size(1) + + if scales.dim() != 2 or scales.size(1) != N: + raise RuntimeError(f"scales must be [G, N={N}]") + G = scales.size(0) + + if K % G != 0: + raise RuntimeError(f"K ({K}) must be divisible by G ({G})") + block_size = K // G + + if qzeros.dim() != 2 or qzeros.size(0) != G or qzeros.size(1) * PACK_FACTOR != N: + raise RuntimeError(f"qzeros must be [G={G}, N/8={N // 8}]") + + # Reshape input to 2D if needed + x_2d = input.reshape(-1, in_features) + + # Build g_idx and shift tables + g_idx = torch.arange(K, device=dev, dtype=torch.int32) // block_size # [K] + wf = torch.arange(PACK_FACTOR, device=dev, dtype=torch.int32) * 4 # [8] + wf_unsqueeze_zero = wf.view(1, 1, PACK_FACTOR) # [1,1,8] + wf_unsqueeze_neg_one = wf.view(1, PACK_FACTOR, 1) # [1,8,1] + + zeros = torch.bitwise_right_shift( + torch.unsqueeze(qzeros, 2).expand(-1, -1, PACK_FACTOR), + wf_unsqueeze_zero, + ).to(dequant_dtype) + zeros = torch.bitwise_and(zeros, MAXQ).reshape(scales.shape) + + weight = torch.bitwise_and( + torch.bitwise_right_shift( + torch.unsqueeze(qweight, 1).expand(-1, PACK_FACTOR, -1), + wf_unsqueeze_neg_one, + ).to(dequant_dtype), + MAXQ, + ) + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + weights = (scales[g_idx.long()] * (weight - zeros[g_idx.long()])).to(input.dtype) + + out = torch.matmul(x_2d, weights) + + if bias is not None: + out = out + bias + + # Reshape output back to match input batch dimensions + out = out.reshape(*input_shape[:-1], N) + + return out + + +@torch_fake_quant_int4_gptq_linear.register_fake +def torch_fake_quant_int4_gptq_linear_fake( + input: torch.Tensor, + weight_quantized: torch.Tensor, + bias: Optional[torch.Tensor], + input_scale: List[torch.Tensor], + weight_scale: List[torch.Tensor], + input_zp: List[torch.Tensor], + weight_zp: List[torch.Tensor], +) -> torch.Tensor: + N = weight_quantized.size(1) + return torch.empty((*input.shape[:-1], N), dtype=input.dtype, device=input.device) diff --git a/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py b/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py index 9a16e0b972..5ec9d69627 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py +++ b/tensorrt_llm/_torch/auto_deploy/models/quant_config_reader.py @@ -146,6 +146,8 @@ class HFQuantConfigReader(QuantConfigReader): Quantization reader that process transformers.quantizers.HFQuantizer functionality """ + _ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens") + def __init__(self): super().__init__() self._hf_quantizer = None @@ -156,6 +158,10 @@ class HFQuantConfigReader(QuantConfigReader): if not qconf: raise ValueError("HF quantization_config not found.") + # Inject default exclusion, add "model.embed_tokens" for "tie_word_embedding:true" case + excludes = qconf.get("exclude_modules", []) + qconf["exclude_modules"] = excludes + [n for n in self._ALWAYS_EXCLUDE if n not in excludes] + self._quant_config = qconf from transformers.quantizers import 
AutoHfQuantizer @@ -177,12 +183,23 @@ class HFQuantConfigReader(QuantConfigReader): if not isinstance(qconf, dict): return None - # TODO(Fridah-nv):this class is only verified with GPT-OSS MXFP4, other hf quantizers + # TODO(Fridah-nv):this class is only verified with GPT-OSS MXFP4 and INT4-GPTQ, other hf quantizers # should have similar workflow and will be added to the pipeline quant_method = str(qconf.get("quant_method", "")).lower() - if quant_method != "mxfp4": + if quant_method not in ["mxfp4", "gptq"]: return None + # Validate GPTQ config: currently only INT4 with group_size=128 is supported + if quant_method == "gptq": + bits = qconf.get("bits") + group_size = qconf.get("group_size") + if bits != 4: + raise ValueError(f"GPTQ quantization only supports bits=4, got bits={bits}") + if group_size != 128: + raise ValueError( + f"GPTQ quantization only supports group_size=128, got group_size={group_size}" + ) + reader = cls() extra_model_kwargs = reader.read_config(raw) return reader, extra_model_kwargs @@ -191,11 +208,11 @@ class HFQuantConfigReader(QuantConfigReader): # more features to be added def post_process_model(self, model, model_config): if self._hf_quantizer is None: - return + return model dtype = getattr(model_config, "dtype", None) new_dtype = self._hf_quantizer.update_dtype(dtype) if new_dtype is None: - return + return model model.to(new_dtype) return model diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 5b2902d642..295d17aedc 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -110,11 +110,13 @@ class Quantization(BaseTransform): shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: qcfg = factory.get_quant_config() - if not qcfg or qcfg.get("quant_algo", "").upper() != self.algo_name: + if not qcfg or ( + qcfg.get("quant_algo", "").upper() != self.algo_name + and qcfg.get("quant_method", "").upper() != self.algo_name + ): return gm, TransformInfo( skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - excluded = qcfg.get("exclude_modules", []) cnt = 0 for n in gm.graph.nodes: @@ -635,3 +637,104 @@ class NVFP4QuantizationFromGraph(NVFP4LinearQuantizationFromConfig): return gm, TransformInfo( skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=True ) + + +@TransformRegistry.register("quantize_int4_gptq_linear_from_config") +class INT4GPTQLinearQuantizationFromConfig(Quantization): + """Config-based INT4 GPTQ quantization for GPTQ-quantized checkpoints. + + GPTQ uses: + - qweight: [K/8, N] int32 (8 packed int4 values per int32) + - qzeros: [G, N/8] int32 (packed zero points) + - scales: [G, N] float (per-group scales) + """ + + algo_name = "GPTQ" + + @staticmethod + def target_op(): + return torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear.default + + @staticmethod + def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor: + """Returns placeholder qweight tensor [K/8, N] int32.""" + N, K = original_weight.shape + assert K % 8 == 0, "K must be divisible by 8 for GPTQ int4 packing." 
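+        # Placeholder buffer only: the real packed int4 values are written by load_hook at checkpoint load time.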
+ return torch.empty((K // 8, N), dtype=torch.int32, device=original_weight.device) + + @staticmethod + def scale_names() -> List[str]: + return ["scales", "qzeros"] + + @staticmethod + def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]: + """Returns placeholder tensors for GPTQ scales and qzeros.""" + N, K = original_weight_shape + BLOCK = 128 # GPTQ group size + assert K % BLOCK == 0, "K must be divisible by 128 for GPTQ block quant." + assert N % 8 == 0, "N must be divisible by 8 for GPTQ qzeros packing." + G = K // BLOCK + return { + "scales": torch.empty((G, N), dtype=torch.float16), + "qzeros": torch.empty((G, N // 8), dtype=torch.int32), + } + + @staticmethod + def build_custom_args_for_linear(scales: Dict[str, Node]) -> Tuple[object, ...]: + """Build args for torch_fake_quant_int4_gptq_linear: + (input, weight, bias, input_scale, weight_scale, input_zp, weight_zp) + -> input_scale=[], weight_scale=[scales], input_zp=[], weight_zp=[qzeros] + """ + return ([], [scales["scales"]], [], [scales["qzeros"]]) + + @staticmethod + def load_hook(state_dict, prefix, *args, weight_name: str): + """ + Load hook for GPTQ checkpoints: + - qweight: keep as [K/8, N] int32 + - scales: [G, N] float16 + - qzeros: [G, N/8] int32 + + GPTQ checkpoint uses naming convention: + - {prefix}qweight + - {prefix}scales + - {prefix}qzeros + """ + + mod_prefix, _, _ = weight_name.rpartition(".") + + qweight_ckpt = f"{mod_prefix}.qweight" + scales_ckpt = f"{mod_prefix}.scales" + qzeros_ckpt = f"{mod_prefix}.qzeros" + + if qweight_ckpt not in state_dict: + return + + qweight = state_dict[qweight_ckpt] + if qweight.dtype != torch.int32: + return + + K_packed, N = qweight.shape # [K/8, N] + K = K_packed * 8 + + assert scales_ckpt in state_dict, f"Missing {scales_ckpt}" + scales = state_dict[scales_ckpt] # [G, N] + G = scales.shape[0] + + assert qzeros_ckpt in state_dict, f"Missing {qzeros_ckpt}" + qzeros = state_dict[qzeros_ckpt] # [G, N/8] + + # Validate GPTQ weight layout + assert K % G == 0, f"K ({K}) must be divisible by G ({G})" + assert scales.shape == (G, N), f"scales shape {scales.shape} != {(G, N)}" + assert qzeros.shape == (G, N // 8), f"qzeros shape {qzeros.shape} != {(G, N // 8)}" + + # Map to our buffer names + state_dict[weight_name] = qweight # [K/8, N] int32 + state_dict[f"{mod_prefix}.scales"] = scales.to(torch.float16) # [G, N] + # GPTQ v1 format stores (zero_point - 1); convert to v2 by adding 0x11111111 + # See: gptqmodel.utils.model.convert_gptq_v1_to_v2_format_module + qzeros_v2 = qzeros + 0x11111111 + state_dict[f"{mod_prefix}.qzeros"] = qzeros_v2 # [G, N/8] int32 (v2 format) + # Remove the original qweight key to avoid "unexpected key" warnings + del state_dict[qweight_ckpt] diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_gptq_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_gptq_op.py new file mode 100644 index 0000000000..1f8e6f5748 --- /dev/null +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_gptq_op.py @@ -0,0 +1,178 @@ +import pytest +import torch + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 + + +def pack_gptq_qweight_from_u4(U4_nk: torch.Tensor) -> torch.Tensor: + """ + GPTQ: pack along K, 8 nibbles per int32. 
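+    Nibble i of each output int32 (bits 4*i..4*i+3) holds input row 8*r + i of the [K, N] layout.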
+ U4_nk: [N,K] uint8 in [0..15] + -> qweight: [K/8, N] int32 + """ + assert U4_nk.dtype == torch.uint8 and U4_nk.dim() == 2 + N, K = U4_nk.shape + assert K % 8 == 0 + shifts = torch.arange(8, dtype=torch.int32).view(1, 8, 1) * 4 # [1,8,1] + U4_kn = U4_nk.T.contiguous() # [K, N] u8 + U4_blocks = U4_kn.view(K // 8, 8, N).to(torch.int32) # [K/8,8,N] + qweight = torch.sum(U4_blocks << shifts, dim=1) # [K/8, N] i32 + return qweight + + +def pack_qzeros_all_8(G: int, N: int) -> torch.Tensor: + """ + Build qzeros: [G, N/8] int32 such that each unpacked nibble == 8. + Each int32 holds 8 nibbles; signed int32 value -0x77777778 has the same + bit pattern as 0x88888888 (unsigned). + """ + assert N % 8 == 0 + val = torch.tensor(-0x77777778, dtype=torch.int32) # == -2004318072 + return val.repeat(G, N // 8) # [G, N/8] + + +def pack_uint8_from_Qs_signed(Qs_nk: torch.Tensor) -> torch.Tensor: + """ + ModelOpt: pack along N, 2 nibbles per byte from signed int4 Qs in [-8..7]. + Qs_nk: [N,K] int8 + -> packed: [N/2, K] uint8 (low nibble = even row, high nibble = odd row) + """ + assert Qs_nk.dtype == torch.int8 and Qs_nk.dim() == 2 + N, K = Qs_nk.shape + assert N % 2 == 0 + + # map signed -> nibble (two's complement) + def to_u4(x: torch.Tensor) -> torch.Tensor: + x16 = x.to(torch.int16) + u = torch.where(x16 >= 0, x16, x16 + 16).to(torch.uint8) # [N,K] in 0..15 + return u + + even_u4 = to_u4(Qs_nk[0::2, :]) # [N/2, K] u8 + odd_u4 = to_u4(Qs_nk[1::2, :]) # [N/2, K] u8 + return (even_u4 | (odd_u4 << 4)).contiguous().to(torch.uint8) # [N/2, K] + + +def gptq_unpack_unsigned_u4_KN( + qweight: torch.Tensor, wf_unsqueeze_neg_one: torch.Tensor +) -> torch.Tensor: + """ + Mirror the custom op's unpack (for the weight path): returns unsigned nibbles [K,N] u8. + """ + pack_factor = 8 + w = torch.bitwise_right_shift( + qweight.unsqueeze(1).expand(-1, pack_factor, -1), # [K/8,8,N] + wf_unsqueeze_neg_one.to(qweight.dtype), # [1,8,1] + ).to(torch.int16) + w = (w & 15).to(torch.uint8).reshape(-1, qweight.shape[1]) # [K,N] u8 + return w + + +def modelopt_unpack_Qs_signed_NK(weight_packed: torch.Tensor) -> torch.Tensor: + """ + Unpack ModelOpt packed bytes back to signed int4 in [-8..7], [N,K] int8. 
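+    Inverse of pack_uint8_from_Qs_signed: the low nibble of each byte is the even output row,
+    the high nibble the odd output row.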
+ """ + pw = weight_packed.T.contiguous() # [K, N/2] u8 + low = (pw & 0x0F).to(torch.int16) # [K, N/2] + high = ((pw >> 4) & 0x0F).to(torch.int16) # [K, N/2] + low_s = torch.where(low >= 8, low - 16, low).to(torch.int8) # [-8..7] + high_s = torch.where(high >= 8, high - 16, high).to(torch.int8) # [-8..7] + rebuilt = torch.stack([low_s, high_s], dim=-1) # [K, N/2, 2] i8 + int8_T = rebuilt.reshape(pw.shape[0], -1) # [K, N] i8 + return int8_T.T.contiguous() # [N, K] i8 + + +def gptq_weight_and_out(input_x, qweight, qzeros, scales_gn, g_idx, wf_zero, wf_neg1): + """ + Exact math that your custom op implements: + weights = scales[g_idx] * (unpacked_u4 - unpacked_qzeros[g_idx]).to(input.dtype) # [K,N] + out = input @ weights + """ + # unpack unsigned nibbles [K,N] + u4_kn = gptq_unpack_unsigned_u4_KN(qweight, wf_neg1).to(torch.int16) # [K,N] + # unpack qzeros to nibbles [G,N] -> broadcast with g_idx + zeros = ( + torch.bitwise_right_shift( + qzeros.unsqueeze(2).expand(-1, -1, 8), # [G, N/8, 8] + wf_zero.to(qzeros.dtype), # [1,1,8] + ).to(torch.int16) + & 15 + ) + zeros_gn = zeros.reshape(scales_gn.shape) # [G,N] + z_kn = zeros_gn[g_idx.long()] # [K,N] int16 + + scale_kn = scales_gn[g_idx.long()].to(torch.float32) # [K,N] + W_kn = scale_kn * (u4_kn - z_kn).to(torch.float32) # [K,N] + y = input_x.to(torch.float32) @ W_kn + return W_kn, y + + +def modelopt_weight_and_out(input_x, weight_packed, weight_scale_ng): + Qs_nk = modelopt_unpack_Qs_signed_NK(weight_packed).to(torch.float32) # [N,K] + S_nk = weight_scale_ng.repeat_interleave(128, dim=1).to(torch.float32) # [N,K] + W_kn = (Qs_nk * S_nk).T.contiguous() # [K,N] + y = input_x.to(torch.float32) @ W_kn + return W_kn, y + + +@pytest.mark.parametrize("N,K,BLOCK_SIZE", [(896, 4864, 128)]) +def test_gptq_vs_modelopt_qzeros_8_match(N, K, BLOCK_SIZE): + torch.manual_seed(0) + G = K // BLOCK_SIZE + assert K % 8 == 0 and K % BLOCK_SIZE == 0 and N % 2 == 0 + + # Ground-truth signed int4 weights Q_s in [-8..7] + Qs_nk = torch.randint(-8, 8, (N, K), dtype=torch.int8) + + # Convert to codebooks for each path + U4_gptq = (Qs_nk.to(torch.int16) + 8).to(torch.uint8) # [N,K] 0..15 + weight_quantized = pack_uint8_from_Qs_signed(Qs_nk) # [N/2, K] u8 + + # Pack GPTQ qweight and qzeros (all nibbles = 8) + qweight_gptq = pack_gptq_qweight_from_u4(U4_gptq) # [K/8, N] i32 + qzeros = pack_qzeros_all_8(G, N) # [G, N/8] i32 + + # Scales: GPTQ stores [G,N], ModelOpt stores [N,G] (transpose) + scales_gn = torch.rand(G, N, dtype=torch.float32) * 2.0 # [G,N] + weight_scale_ng = scales_gn.T.contiguous() # [N,G] + + # Index & shifts + g_idx = torch.arange(K, dtype=torch.int32) // BLOCK_SIZE # [K] + wf = torch.arange(8, dtype=torch.int32) * 4 + wf_zero = wf.view(1, 1, 8) # [1,1,8] + wf_neg1 = wf.view(1, 8, 1) # [1,8,1] + + x = torch.randn(3, K, dtype=torch.float32) + + Wg, yg = gptq_weight_and_out(x, qweight_gptq, qzeros, scales_gn, g_idx, wf_zero, wf_neg1) + Wm, ym = modelopt_weight_and_out(x, weight_quantized, weight_scale_ng) + + torch.testing.assert_close(Wg, Wm, rtol=0, atol=0) + torch.testing.assert_close(yg, ym, rtol=0, atol=0) + + bias = None + pre_scale = torch.tensor(1.0, dtype=torch.float32) + input_scale_list = [pre_scale] + weight_scale_list = [weight_scale_ng] + input_zp_list, weight_zp_list = [torch.tensor(0)], [torch.tensor(0)] + + y_gptq = torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear( + x, + qweight_gptq, + None, # bias + [], # input_scale + [scales_gn], # weight_scale + [], # input_zp + [qzeros], # weight_zp + ) + y_mo = 
torch.ops.auto_deploy.torch_fake_quant_int4_linear( + x, + weight_quantized, + bias, + input_scale_list, + weight_scale_list, + input_zp_list, + weight_zp_list, + ) + + # small mismatch ≈ 5/2048, likely from the GEMM calculation + torch.testing.assert_close(y_gptq, y_mo, rtol=0, atol=3e-3) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py index 38caf1c5dc..ade54ba146 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py @@ -251,7 +251,119 @@ def test_int4awq_transform_graph_and_load_hook(): True, # test_load_hook False, # strict_loading None, # dynamic_shapes - None, # check_num_matches + False, # skip_output_assert + quant_config, + ) + + # Still exportable + torch_export_to_gm(gm_transformed, args=(x,)) + + +def _pack_gptq_qweight(weights: torch.Tensor, scales: torch.Tensor) -> torch.Tensor: + """ + Pack float weights to GPTQ qweight format [K/8, N] int32. + Uses GPTQ symmetric quantization: signed int4 [-8,7] stored as unsigned [0,15]. + + weights: [N, K] float + scales: [G, N] float (per-group scales) + Returns: [K/8, N] int32 packed qweight + """ + N, K = weights.shape + G = scales.shape[0] + BLOCK = K // G + assert K % 8 == 0 + + # Expand scales from [G, N] to [N, K] for per-element division + scales_nk = scales.T.repeat_interleave(BLOCK, dim=1) # [N, K] + + # Quantize: w / scale gives float in approx [-8, 7] range, round to int + # GPTQ uses zero_point=8, so signed int4 [-8,7] maps to unsigned [0,15] + q_signed = (weights / scales_nk).round().clamp(-8, 7).to(torch.int8) # [N, K] in [-8,7] + u4_nk = (q_signed.to(torch.int16) + 8).to(torch.uint8) # [N, K] in [0,15] + + # Pack along K: [N, K] -> [K/8, N] int32 + # Use explicit int32 ops to avoid promotion to int64 + u4_kn = u4_nk.T.contiguous() # [K, N] + u4_blocks = u4_kn.view(K // 8, 8, N).to(torch.int32) # [K/8, 8, N] + shifts = [0, 4, 8, 12, 16, 20, 24, 28] + qweight = torch.zeros(K // 8, N, dtype=torch.int32, device=weights.device) + for i in range(8): + qweight = qweight | (u4_blocks[:, i, :] << shifts[i]) + return qweight + + +def _pack_gptq_qzeros_v1(G: int, N: int, device: torch.device) -> torch.Tensor: + """ + Build qzeros [G, N/8] int32 in GPTQ v1 format. + v1 stores (zero_point - 1) = 7 per nibble -> 0x77777777 per int32. 
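+    The transform's load_hook adds 0x11111111, giving the v2 bit pattern 0x88888888 (zero point 8 per nibble).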
+ """ + assert N % 8 == 0 + val = torch.tensor(0x77777777, dtype=torch.int32, device=device) + return val.repeat(G, N // 8) + + +def test_int4gptq_transform_graph_and_load_hook(): + """INT4 GPTQ transform with load_hook converting v1 qzeros to v2.""" + device = "cuda" + torch.manual_seed(42) + quant_config = {"quant_algo": "GPTQ"} + BLOCK = 128 # GPTQ group size + QUANT_OP = torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear + + # Model with K divisible by 128, N divisible by 8 + model = MLP(256, 128, 256).to(torch.float16).to(device) + x = torch.randn(3, 256, dtype=torch.float16, device=device) + + def gptq_state_dict_hook(module: nn.Module, state_dict: dict, prefix: str, local_meta: dict): + """Convert FP weights to GPTQ checkpoint format (v1 qzeros).""" + for name, m in module.named_modules(): + if not isinstance(m, nn.Linear): + continue + key_w = f"{prefix}{name}.weight" + if key_w not in state_dict: + continue + + W = state_dict[key_w].detach().to(torch.float32).to(device) # [N, K] + N, K = W.shape + assert N % 8 == 0 and K % BLOCK == 0 + + G = K // BLOCK + + # Per-group scales from per-group amax + amax = W.abs().view(N, G, BLOCK).amax(dim=-1) # [N, G] + scales = (amax / 7.0).T.to(torch.float32) # [G, N] float32 for quantization + + # Build GPTQ tensors using symmetric quantization with zero_point=8 + qweight = _pack_gptq_qweight(W, scales) # [K/8, N] int32 + qzeros = _pack_gptq_qzeros_v1(G, N, W.device) # [G, N/8] int32 (v1, zero_point=8) + scales = scales.to(torch.float16) # [G, N] convert to fp16 for storage + + # Replace weight with qweight, add GPTQ-specific keys + state_dict[f"{prefix}{name}.qweight"] = qweight + state_dict[f"{prefix}{name}.scales"] = scales + state_dict[f"{prefix}{name}.qzeros"] = qzeros + # Remove original weight key + del state_dict[key_w] + + model._register_state_dict_hook(gptq_state_dict_hook) + + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + DummyFactory(quant_config), + {"quantize_int4_gptq_linear_from_config": {"stage": "pattern_matcher"}}, + )(None, gm).to(device) + + run_test_transformed_gm( + model, + x, + gm_transformed, + lambda gm_: any(is_op(n, QUANT_OP) for n in gm_.graph.nodes), + lambda num_p_og: num_p_og // 8, # qweight packs 8 int4 values per int32 + 0.5, # atol (quantization error expected) + 0.5, # rtol + True, # test_load_hook + False, # strict_loading + None, # dynamic_shapes False, # skip_output_assert quant_config, )