[#8242][feat] Add int4 GPTQ support for AutoDeploy (#8248)

Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
Frida Hou 2026-01-30 23:07:24 -08:00 committed by GitHub
parent 6bace84167
commit 7910d4d2a9
7 changed files with 516 additions and 7 deletions


@@ -5,6 +5,8 @@ models:
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml']
- name: Qwen/Qwen2.5-0.5B-Instruct
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml']
- name: Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml']
- name: Qwen/Qwen3-0.6B
yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml']
# DISABLED: TorchDynamo compilation error - fake tensor dispatch failure


@@ -63,6 +63,8 @@ transforms:
stage: pattern_matcher
quantize_int4_linear_from_config:
stage: pattern_matcher
quantize_int4_gptq_linear_from_config:
stage: pattern_matcher
quantize_fp8_linear_from_config:
stage: pattern_matcher
quantize_nvfp4_linear_from_config:


@@ -319,3 +319,98 @@ def _fake(
N_half = weight_quantized.shape[-2]
N = N_half * 2
return torch.empty((*input.shape[:-1], N), dtype=input.dtype, device=input.device)
@torch.library.custom_op("auto_deploy::torch_fake_quant_int4_gptq_linear", mutates_args=())
def torch_fake_quant_int4_gptq_linear(
input: torch.Tensor, # [..., K]
weight_quantized: torch.Tensor, # qweight [K/8, N] int32 (packed)
bias: Optional[torch.Tensor], # [N] or None
input_scale: List[torch.Tensor], # unused for GPTQ
weight_scale: List[torch.Tensor], # GPTQ scales [G, N]
input_zp: List[torch.Tensor], # unused for GPTQ
weight_zp: List[torch.Tensor], # GPTQ qzeros [G, N/8] int32
) -> torch.Tensor:
"""
GPTQ INT4 linear with a signature compatible with the other quant ops.
- weight_quantized: qweight [K/8, N] packed int32
- weight_scale[0]: scales [G, N]
- weight_zp[0]: qzeros [G, N/8] packed int32
"""
PACK_FACTOR = 8
MAXQ = 15
dequant_dtype = torch.int8
qweight = weight_quantized
scales = _expect_single_scale(weight_scale, "weight_scale")
qzeros = _expect_single_scale(weight_zp, "weight_zp")
dev = qweight.device
input_shape = input.shape
in_features = input_shape[-1]
if qweight.dim() != 2:
raise RuntimeError("qweight must be 2D [K/8, N]")
K = qweight.size(0) * PACK_FACTOR
N = qweight.size(1)
if scales.dim() != 2 or scales.size(1) != N:
raise RuntimeError(f"scales must be [G, N={N}]")
G = scales.size(0)
if K % G != 0:
raise RuntimeError(f"K ({K}) must be divisible by G ({G})")
block_size = K // G
if qzeros.dim() != 2 or qzeros.size(0) != G or qzeros.size(1) * PACK_FACTOR != N:
raise RuntimeError(f"qzeros must be [G={G}, N/8={N // 8}]")
# Reshape input to 2D if needed
x_2d = input.reshape(-1, in_features)
# Build g_idx and shift tables
g_idx = torch.arange(K, device=dev, dtype=torch.int32) // block_size # [K]
wf = torch.arange(PACK_FACTOR, device=dev, dtype=torch.int32) * 4 # [8]
wf_unsqueeze_zero = wf.view(1, 1, PACK_FACTOR) # [1,1,8]
wf_unsqueeze_neg_one = wf.view(1, PACK_FACTOR, 1) # [1,8,1]
zeros = torch.bitwise_right_shift(
torch.unsqueeze(qzeros, 2).expand(-1, -1, PACK_FACTOR),
wf_unsqueeze_zero,
).to(dequant_dtype)
zeros = torch.bitwise_and(zeros, MAXQ).reshape(scales.shape)
weight = torch.bitwise_and(
torch.bitwise_right_shift(
torch.unsqueeze(qweight, 1).expand(-1, PACK_FACTOR, -1),
wf_unsqueeze_neg_one,
).to(dequant_dtype),
MAXQ,
)
weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])
weights = (scales[g_idx.long()] * (weight - zeros[g_idx.long()])).to(input.dtype)
out = torch.matmul(x_2d, weights)
if bias is not None:
out = out + bias
# Reshape output back to match input batch dimensions
out = out.reshape(*input_shape[:-1], N)
return out
@torch_fake_quant_int4_gptq_linear.register_fake
def torch_fake_quant_int4_gptq_linear_fake(
input: torch.Tensor,
weight_quantized: torch.Tensor,
bias: Optional[torch.Tensor],
input_scale: List[torch.Tensor],
weight_scale: List[torch.Tensor],
input_zp: List[torch.Tensor],
weight_zp: List[torch.Tensor],
) -> torch.Tensor:
N = weight_quantized.size(1)
return torch.empty((*input.shape[:-1], N), dtype=input.dtype, device=input.device)
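As an aside, here is a minimal standalone sketch (not part of the diff) of the int4-in-int32 layout that the op above unpacks: eight unsigned 4-bit codes are packed along K into each int32 of qweight, and nibble i is recovered by shifting right 4*i bits and masking with 0xF, the same shift-and-mask the op applies to qweight and qzeros. The values below are illustrative.
import torch

codes = torch.tensor([1, 5, 8, 15, 0, 7, 12, 3], dtype=torch.int32)  # eight codes in [0, 15]
shifts = torch.arange(8, dtype=torch.int32) * 4

packed = (codes << shifts).sum(dtype=torch.int32)  # one int32 word holding all eight nibbles
unpacked = (packed >> shifts) & 0xF                # same shift-and-mask as the custom op
assert torch.equal(unpacked, codes)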


@@ -146,6 +146,8 @@ class HFQuantConfigReader(QuantConfigReader):
Quantization reader that processes transformers.quantizers.HFQuantizer functionality
"""
_ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens")
def __init__(self):
super().__init__()
self._hf_quantizer = None
@@ -156,6 +158,10 @@
if not qconf:
raise ValueError("HF quantization_config not found.")
# Inject default exclusions; add "model.embed_tokens" for the "tie_word_embeddings: true" case
excludes = qconf.get("exclude_modules", [])
qconf["exclude_modules"] = excludes + [n for n in self._ALWAYS_EXCLUDE if n not in excludes]
self._quant_config = qconf
from transformers.quantizers import AutoHfQuantizer
@@ -177,12 +183,23 @@ class HFQuantConfigReader(QuantConfigReader):
if not isinstance(qconf, dict):
return None
# TODO(Fridah-nv):this class is only verified with GPT-OSS MXFP4, other hf quantizers
# TODO(Fridah-nv):this class is only verified with GPT-OSS MXFP4 and INT4-GPTQ, other hf quantizers
# should have similar workflow and will be added to the pipeline
quant_method = str(qconf.get("quant_method", "")).lower()
if quant_method != "mxfp4":
if quant_method not in ["mxfp4", "gptq"]:
return None
# Validate GPTQ config: currently only INT4 with group_size=128 is supported
if quant_method == "gptq":
bits = qconf.get("bits")
group_size = qconf.get("group_size")
if bits != 4:
raise ValueError(f"GPTQ quantization only supports bits=4, got bits={bits}")
if group_size != 128:
raise ValueError(
f"GPTQ quantization only supports group_size=128, got group_size={group_size}"
)
reader = cls()
extra_model_kwargs = reader.read_config(raw)
return reader, extra_model_kwargs
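For reference, the quantization_config block this branch consumes looks roughly like the sketch below; bits=4 and group_size=128 are the only values accepted by the validation above, and the remaining fields are illustrative assumptions rather than part of this change.
qconf = {
    "quant_method": "gptq",
    "bits": 4,          # only 4 is accepted
    "group_size": 128,  # only 128 is accepted
    "sym": True,        # assumed: symmetric scheme (zero point 8)
    "desc_act": False,  # assumed: no activation-order reindexing
}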
@@ -191,11 +208,11 @@ class HFQuantConfigReader(QuantConfigReader):
# more features to be added
def post_process_model(self, model, model_config):
if self._hf_quantizer is None:
return
return model
dtype = getattr(model_config, "dtype", None)
new_dtype = self._hf_quantizer.update_dtype(dtype)
if new_dtype is None:
return
return model
model.to(new_dtype)
return model


@@ -110,11 +110,13 @@ class Quantization(BaseTransform):
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
qcfg = factory.get_quant_config()
if not qcfg or qcfg.get("quant_algo", "").upper() != self.algo_name:
if not qcfg or (
qcfg.get("quant_algo", "").upper() != self.algo_name
and qcfg.get("quant_method", "").upper() != self.algo_name
):
return gm, TransformInfo(
skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
)
excluded = qcfg.get("exclude_modules", [])
cnt = 0
for n in gm.graph.nodes:
@@ -635,3 +637,104 @@ class NVFP4QuantizationFromGraph(NVFP4LinearQuantizationFromConfig):
return gm, TransformInfo(
skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=True
)
@TransformRegistry.register("quantize_int4_gptq_linear_from_config")
class INT4GPTQLinearQuantizationFromConfig(Quantization):
"""Config-based INT4 GPTQ quantization for GPTQ-quantized checkpoints.
GPTQ uses:
- qweight: [K/8, N] int32 (8 packed int4 values per int32)
- qzeros: [G, N/8] int32 (packed zero points)
- scales: [G, N] float (per-group scales)
"""
algo_name = "GPTQ"
@staticmethod
def target_op():
return torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear.default
@staticmethod
def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor:
"""Returns placeholder qweight tensor [K/8, N] int32."""
N, K = original_weight.shape
assert K % 8 == 0, "K must be divisible by 8 for GPTQ int4 packing."
return torch.empty((K // 8, N), dtype=torch.int32, device=original_weight.device)
@staticmethod
def scale_names() -> List[str]:
return ["scales", "qzeros"]
@staticmethod
def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
"""Returns placeholder tensors for GPTQ scales and qzeros."""
N, K = original_weight_shape
BLOCK = 128 # GPTQ group size
assert K % BLOCK == 0, "K must be divisible by 128 for GPTQ block quant."
assert N % 8 == 0, "N must be divisible by 8 for GPTQ qzeros packing."
G = K // BLOCK
return {
"scales": torch.empty((G, N), dtype=torch.float16),
"qzeros": torch.empty((G, N // 8), dtype=torch.int32),
}
@staticmethod
def build_custom_args_for_linear(scales: Dict[str, Node]) -> Tuple[object, ...]:
"""Build args for torch_fake_quant_int4_gptq_linear:
(input, weight, bias, input_scale, weight_scale, input_zp, weight_zp)
-> input_scale=[], weight_scale=[scales], input_zp=[], weight_zp=[qzeros]
"""
return ([], [scales["scales"]], [], [scales["qzeros"]])
@staticmethod
def load_hook(state_dict, prefix, *args, weight_name: str):
"""
Load hook for GPTQ checkpoints:
- qweight: keep as [K/8, N] int32
- scales: [G, N] float16
- qzeros: [G, N/8] int32
GPTQ checkpoints use the naming convention:
- {prefix}qweight
- {prefix}scales
- {prefix}qzeros
"""
mod_prefix, _, _ = weight_name.rpartition(".")
qweight_ckpt = f"{mod_prefix}.qweight"
scales_ckpt = f"{mod_prefix}.scales"
qzeros_ckpt = f"{mod_prefix}.qzeros"
if qweight_ckpt not in state_dict:
return
qweight = state_dict[qweight_ckpt]
if qweight.dtype != torch.int32:
return
K_packed, N = qweight.shape # [K/8, N]
K = K_packed * 8
assert scales_ckpt in state_dict, f"Missing {scales_ckpt}"
scales = state_dict[scales_ckpt] # [G, N]
G = scales.shape[0]
assert qzeros_ckpt in state_dict, f"Missing {qzeros_ckpt}"
qzeros = state_dict[qzeros_ckpt] # [G, N/8]
# Validate GPTQ weight layout
assert K % G == 0, f"K ({K}) must be divisible by G ({G})"
assert scales.shape == (G, N), f"scales shape {scales.shape} != {(G, N)}"
assert qzeros.shape == (G, N // 8), f"qzeros shape {qzeros.shape} != {(G, N // 8)}"
# Map to our buffer names
state_dict[weight_name] = qweight # [K/8, N] int32
state_dict[f"{mod_prefix}.scales"] = scales.to(torch.float16) # [G, N]
# GPTQ v1 format stores (zero_point - 1); convert to v2 by adding 0x11111111
# See: gptqmodel.utils.model.convert_gptq_v1_to_v2_format_module
qzeros_v2 = qzeros + 0x11111111
state_dict[f"{mod_prefix}.qzeros"] = qzeros_v2 # [G, N/8] int32 (v2 format)
# Remove the original qweight key to avoid "unexpected key" warnings
del state_dict[qweight_ckpt]
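A small standalone check of the v1 -> v2 qzeros conversion performed by the load hook, assuming every zero point is 8 (symmetric int4): v1 stores zero_point - 1 = 7 in each nibble, i.e. 0x77777777 per int32, and adding 0x11111111 bumps every nibble to 8, wrapping the int32 to the 0x88888888 bit pattern.
import torch

qzeros_v1 = torch.tensor([0x77777777], dtype=torch.int32)  # eight nibbles of 7 (zero_point - 1)
qzeros_v2 = qzeros_v1 + 0x11111111                         # eight nibbles of 8; wraps past INT32_MAX

shifts = torch.arange(8, dtype=torch.int32) * 4
assert ((qzeros_v2 >> shifts) & 0xF).tolist() == [8] * 8
assert qzeros_v2.item() == -0x77777778  # the value pack_qzeros_all_8 builds directly in the unit test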


@@ -0,0 +1,178 @@
import pytest
import torch
import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401
def pack_gptq_qweight_from_u4(U4_nk: torch.Tensor) -> torch.Tensor:
"""
GPTQ: pack along K, 8 nibbles per int32.
U4_nk: [N,K] uint8 in [0..15]
-> qweight: [K/8, N] int32
"""
assert U4_nk.dtype == torch.uint8 and U4_nk.dim() == 2
N, K = U4_nk.shape
assert K % 8 == 0
shifts = torch.arange(8, dtype=torch.int32).view(1, 8, 1) * 4 # [1,8,1]
U4_kn = U4_nk.T.contiguous() # [K, N] u8
U4_blocks = U4_kn.view(K // 8, 8, N).to(torch.int32) # [K/8,8,N]
qweight = torch.sum(U4_blocks << shifts, dim=1) # [K/8, N] i32
return qweight
def pack_qzeros_all_8(G: int, N: int) -> torch.Tensor:
"""
Build qzeros: [G, N/8] int32 such that each unpacked nibble == 8.
Each int32 holds 8 nibbles; signed int32 value -0x77777778 has the same
bit pattern as 0x88888888 (unsigned).
"""
assert N % 8 == 0
val = torch.tensor(-0x77777778, dtype=torch.int32) # == -2004318072
return val.repeat(G, N // 8) # [G, N/8]
def pack_uint8_from_Qs_signed(Qs_nk: torch.Tensor) -> torch.Tensor:
"""
ModelOpt: pack along N, 2 nibbles per byte from signed int4 Qs in [-8..7].
Qs_nk: [N,K] int8
-> packed: [N/2, K] uint8 (low nibble = even row, high nibble = odd row)
"""
assert Qs_nk.dtype == torch.int8 and Qs_nk.dim() == 2
N, K = Qs_nk.shape
assert N % 2 == 0
# map signed -> nibble (two's complement)
def to_u4(x: torch.Tensor) -> torch.Tensor:
x16 = x.to(torch.int16)
u = torch.where(x16 >= 0, x16, x16 + 16).to(torch.uint8) # [N,K] in 0..15
return u
even_u4 = to_u4(Qs_nk[0::2, :]) # [N/2, K] u8
odd_u4 = to_u4(Qs_nk[1::2, :]) # [N/2, K] u8
return (even_u4 | (odd_u4 << 4)).contiguous().to(torch.uint8) # [N/2, K]
def gptq_unpack_unsigned_u4_KN(
qweight: torch.Tensor, wf_unsqueeze_neg_one: torch.Tensor
) -> torch.Tensor:
"""
Mirror the custom op's unpack (for the weight path): returns unsigned nibbles [K,N] u8.
"""
pack_factor = 8
w = torch.bitwise_right_shift(
qweight.unsqueeze(1).expand(-1, pack_factor, -1), # [K/8,8,N]
wf_unsqueeze_neg_one.to(qweight.dtype), # [1,8,1]
).to(torch.int16)
w = (w & 15).to(torch.uint8).reshape(-1, qweight.shape[1]) # [K,N] u8
return w
def modelopt_unpack_Qs_signed_NK(weight_packed: torch.Tensor) -> torch.Tensor:
"""
Unpack ModelOpt packed bytes back to signed int4 in [-8..7], [N,K] int8.
"""
pw = weight_packed.T.contiguous() # [K, N/2] u8
low = (pw & 0x0F).to(torch.int16) # [K, N/2]
high = ((pw >> 4) & 0x0F).to(torch.int16) # [K, N/2]
low_s = torch.where(low >= 8, low - 16, low).to(torch.int8) # [-8..7]
high_s = torch.where(high >= 8, high - 16, high).to(torch.int8) # [-8..7]
rebuilt = torch.stack([low_s, high_s], dim=-1) # [K, N/2, 2] i8
int8_T = rebuilt.reshape(pw.shape[0], -1) # [K, N] i8
return int8_T.T.contiguous() # [N, K] i8
def gptq_weight_and_out(input_x, qweight, qzeros, scales_gn, g_idx, wf_zero, wf_neg1):
"""
Exact math that the custom op implements:
weights = scales[g_idx] * (unpacked_u4 - unpacked_qzeros[g_idx]).to(input.dtype) # [K,N]
out = input @ weights
"""
# unpack unsigned nibbles [K,N]
u4_kn = gptq_unpack_unsigned_u4_KN(qweight, wf_neg1).to(torch.int16) # [K,N]
# unpack qzeros to nibbles [G,N] -> broadcast with g_idx
zeros = (
torch.bitwise_right_shift(
qzeros.unsqueeze(2).expand(-1, -1, 8), # [G, N/8, 8]
wf_zero.to(qzeros.dtype), # [1,1,8]
).to(torch.int16)
& 15
)
zeros_gn = zeros.reshape(scales_gn.shape) # [G,N]
z_kn = zeros_gn[g_idx.long()] # [K,N] int16
scale_kn = scales_gn[g_idx.long()].to(torch.float32) # [K,N]
W_kn = scale_kn * (u4_kn - z_kn).to(torch.float32) # [K,N]
y = input_x.to(torch.float32) @ W_kn
return W_kn, y
def modelopt_weight_and_out(input_x, weight_packed, weight_scale_ng):
Qs_nk = modelopt_unpack_Qs_signed_NK(weight_packed).to(torch.float32) # [N,K]
S_nk = weight_scale_ng.repeat_interleave(128, dim=1).to(torch.float32) # [N,K]
W_kn = (Qs_nk * S_nk).T.contiguous() # [K,N]
y = input_x.to(torch.float32) @ W_kn
return W_kn, y
@pytest.mark.parametrize("N,K,BLOCK_SIZE", [(896, 4864, 128)])
def test_gptq_vs_modelopt_qzeros_8_match(N, K, BLOCK_SIZE):
torch.manual_seed(0)
G = K // BLOCK_SIZE
assert K % 8 == 0 and K % BLOCK_SIZE == 0 and N % 2 == 0
# Ground-truth signed int4 weights Q_s in [-8..7]
Qs_nk = torch.randint(-8, 8, (N, K), dtype=torch.int8)
# Convert to codebooks for each path
U4_gptq = (Qs_nk.to(torch.int16) + 8).to(torch.uint8) # [N,K] 0..15
weight_quantized = pack_uint8_from_Qs_signed(Qs_nk) # [N/2, K] u8
# Pack GPTQ qweight and qzeros (all nibbles = 8)
qweight_gptq = pack_gptq_qweight_from_u4(U4_gptq) # [K/8, N] i32
qzeros = pack_qzeros_all_8(G, N) # [G, N/8] i32
# Scales: GPTQ stores [G,N], ModelOpt stores [N,G] (transpose)
scales_gn = torch.rand(G, N, dtype=torch.float32) * 2.0 # [G,N]
weight_scale_ng = scales_gn.T.contiguous() # [N,G]
# Index & shifts
g_idx = torch.arange(K, dtype=torch.int32) // BLOCK_SIZE # [K]
wf = torch.arange(8, dtype=torch.int32) * 4
wf_zero = wf.view(1, 1, 8) # [1,1,8]
wf_neg1 = wf.view(1, 8, 1) # [1,8,1]
x = torch.randn(3, K, dtype=torch.float32)
Wg, yg = gptq_weight_and_out(x, qweight_gptq, qzeros, scales_gn, g_idx, wf_zero, wf_neg1)
Wm, ym = modelopt_weight_and_out(x, weight_quantized, weight_scale_ng)
torch.testing.assert_close(Wg, Wm, rtol=0, atol=0)
torch.testing.assert_close(yg, ym, rtol=0, atol=0)
bias = None
pre_scale = torch.tensor(1.0, dtype=torch.float32)
input_scale_list = [pre_scale]
weight_scale_list = [weight_scale_ng]
input_zp_list, weight_zp_list = [torch.tensor(0)], [torch.tensor(0)]
y_gptq = torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear(
x,
qweight_gptq,
None, # bias
[], # input_scale
[scales_gn], # weight_scale
[], # input_zp
[qzeros], # weight_zp
)
y_mo = torch.ops.auto_deploy.torch_fake_quant_int4_linear(
x,
weight_quantized,
bias,
input_scale_list,
weight_scale_list,
input_zp_list,
weight_zp_list,
)
# Small numerical mismatch (≈ 5/2048), likely from floating-point accumulation in the GEMM
torch.testing.assert_close(y_gptq, y_mo, rtol=0, atol=3e-3)
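The exact equality asserted for Wg and Wm above rests on the zero-point-8 identity: GPTQ's unsigned code is u = q + 8 for a signed int4 q in [-8, 7], so scale * (u - 8) == scale * q, which is precisely the ModelOpt signed dequantization. A trivial standalone check:
import torch

q = torch.arange(-8, 8, dtype=torch.int16)  # every signed int4 code
u = q + 8                                   # corresponding GPTQ unsigned code with zero point 8
scale = 0.03                                # arbitrary per-group scale
assert torch.equal(scale * (u - 8).float(), scale * q.float())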


@@ -251,7 +251,119 @@ def test_int4awq_transform_graph_and_load_hook():
True, # test_load_hook
False, # strict_loading
None, # dynamic_shapes
None, # check_num_matches
False, # skip_output_assert
quant_config,
)
# Still exportable
torch_export_to_gm(gm_transformed, args=(x,))
def _pack_gptq_qweight(weights: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
"""
Pack float weights to GPTQ qweight format [K/8, N] int32.
Uses GPTQ symmetric quantization: signed int4 [-8,7] stored as unsigned [0,15].
weights: [N, K] float
scales: [G, N] float (per-group scales)
Returns: [K/8, N] int32 packed qweight
"""
N, K = weights.shape
G = scales.shape[0]
BLOCK = K // G
assert K % 8 == 0
# Expand scales from [G, N] to [N, K] for per-element division
scales_nk = scales.T.repeat_interleave(BLOCK, dim=1) # [N, K]
# Quantize: w / scale gives float in approx [-8, 7] range, round to int
# GPTQ uses zero_point=8, so signed int4 [-8,7] maps to unsigned [0,15]
q_signed = (weights / scales_nk).round().clamp(-8, 7).to(torch.int8) # [N, K] in [-8,7]
u4_nk = (q_signed.to(torch.int16) + 8).to(torch.uint8) # [N, K] in [0,15]
# Pack along K: [N, K] -> [K/8, N] int32
# Use explicit int32 ops to avoid promotion to int64
u4_kn = u4_nk.T.contiguous() # [K, N]
u4_blocks = u4_kn.view(K // 8, 8, N).to(torch.int32) # [K/8, 8, N]
shifts = [0, 4, 8, 12, 16, 20, 24, 28]
qweight = torch.zeros(K // 8, N, dtype=torch.int32, device=weights.device)
for i in range(8):
qweight = qweight | (u4_blocks[:, i, :] << shifts[i])
return qweight
def _pack_gptq_qzeros_v1(G: int, N: int, device: torch.device) -> torch.Tensor:
"""
Build qzeros [G, N/8] int32 in GPTQ v1 format.
v1 stores (zero_point - 1) = 7 per nibble -> 0x77777777 per int32.
"""
assert N % 8 == 0
val = torch.tensor(0x77777777, dtype=torch.int32, device=device)
return val.repeat(G, N // 8)
def test_int4gptq_transform_graph_and_load_hook():
"""INT4 GPTQ transform with load_hook converting v1 qzeros to v2."""
device = "cuda"
torch.manual_seed(42)
quant_config = {"quant_algo": "GPTQ"}
BLOCK = 128 # GPTQ group size
QUANT_OP = torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear
# Model with K divisible by 128, N divisible by 8
model = MLP(256, 128, 256).to(torch.float16).to(device)
x = torch.randn(3, 256, dtype=torch.float16, device=device)
def gptq_state_dict_hook(module: nn.Module, state_dict: dict, prefix: str, local_meta: dict):
"""Convert FP weights to GPTQ checkpoint format (v1 qzeros)."""
for name, m in module.named_modules():
if not isinstance(m, nn.Linear):
continue
key_w = f"{prefix}{name}.weight"
if key_w not in state_dict:
continue
W = state_dict[key_w].detach().to(torch.float32).to(device) # [N, K]
N, K = W.shape
assert N % 8 == 0 and K % BLOCK == 0
G = K // BLOCK
# Per-group scales from per-group amax
amax = W.abs().view(N, G, BLOCK).amax(dim=-1) # [N, G]
scales = (amax / 7.0).T.to(torch.float32) # [G, N] float32 for quantization
# Build GPTQ tensors using symmetric quantization with zero_point=8
qweight = _pack_gptq_qweight(W, scales) # [K/8, N] int32
qzeros = _pack_gptq_qzeros_v1(G, N, W.device) # [G, N/8] int32 (v1, zero_point=8)
scales = scales.to(torch.float16) # [G, N] convert to fp16 for storage
# Replace weight with qweight, add GPTQ-specific keys
state_dict[f"{prefix}{name}.qweight"] = qweight
state_dict[f"{prefix}{name}.scales"] = scales
state_dict[f"{prefix}{name}.qzeros"] = qzeros
# Remove original weight key
del state_dict[key_w]
model._register_state_dict_hook(gptq_state_dict_hook)
gm = torch_export_to_gm(model, args=(x,), clone=True)
gm_transformed = InferenceOptimizer(
DummyFactory(quant_config),
{"quantize_int4_gptq_linear_from_config": {"stage": "pattern_matcher"}},
)(None, gm).to(device)
run_test_transformed_gm(
model,
x,
gm_transformed,
lambda gm_: any(is_op(n, QUANT_OP) for n in gm_.graph.nodes),
lambda num_p_og: num_p_og // 8, # qweight packs 8 int4 values per int32
0.5, # atol (quantization error expected)
0.5, # rtol
True, # test_load_hook
False, # strict_loading
None, # dynamic_shapes
False, # skip_output_assert
quant_config,
)