Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com>
This commit is contained in:
parent 6bace84167
commit 7910d4d2a9
@@ -5,6 +5,8 @@ models:
    yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml']
  - name: Qwen/Qwen2.5-0.5B-Instruct
    yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml']
  - name: Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4
    yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml']
  - name: Qwen/Qwen3-0.6B
    yaml_extra: ['dashboard_default.yaml', 'world_size_1.yaml']
  # DISABLED: TorchDynamo compilation error - fake tensor dispatch failure
@@ -63,6 +63,8 @@ transforms:
    stage: pattern_matcher
  quantize_int4_linear_from_config:
    stage: pattern_matcher
  quantize_int4_gptq_linear_from_config:
    stage: pattern_matcher
  quantize_fp8_linear_from_config:
    stage: pattern_matcher
  quantize_nvfp4_linear_from_config:
@@ -319,3 +319,98 @@ def _fake(
    N_half = weight_quantized.shape[-2]
    N = N_half * 2
    return torch.empty((*input.shape[:-1], N), dtype=input.dtype, device=input.device)


@torch.library.custom_op("auto_deploy::torch_fake_quant_int4_gptq_linear", mutates_args=())
def torch_fake_quant_int4_gptq_linear(
    input: torch.Tensor,  # [..., K]
    weight_quantized: torch.Tensor,  # qweight [K/8, N] int32 (packed)
    bias: Optional[torch.Tensor],  # [N] or None
    input_scale: List[torch.Tensor],  # unused for GPTQ
    weight_scale: List[torch.Tensor],  # GPTQ scales [G, N]
    input_zp: List[torch.Tensor],  # unused for GPTQ
    weight_zp: List[torch.Tensor],  # GPTQ qzeros [G, N/8] int32
) -> torch.Tensor:
    """GPTQ INT4 linear with a signature compatible with the other quant ops.

    - weight_quantized: qweight [K/8, N] packed int32
    - weight_scale[0]: scales [G, N]
    - weight_zp[0]: qzeros [G, N/8] packed int32
    """
    PACK_FACTOR = 8
    MAXQ = 15
    dequant_dtype = torch.int8

    qweight = weight_quantized
    scales = _expect_single_scale(weight_scale, "weight_scale")
    qzeros = _expect_single_scale(weight_zp, "weight_zp")

    dev = qweight.device
    input_shape = input.shape
    in_features = input_shape[-1]

    if qweight.dim() != 2:
        raise RuntimeError("qweight must be 2D [K/8, N]")
    K = qweight.size(0) * PACK_FACTOR
    N = qweight.size(1)

    if scales.dim() != 2 or scales.size(1) != N:
        raise RuntimeError(f"scales must be [G, N={N}]")
    G = scales.size(0)

    if K % G != 0:
        raise RuntimeError(f"K ({K}) must be divisible by G ({G})")
    block_size = K // G

    if qzeros.dim() != 2 or qzeros.size(0) != G or qzeros.size(1) * PACK_FACTOR != N:
        raise RuntimeError(f"qzeros must be [G={G}, N/8={N // 8}]")

    # Reshape input to 2D if needed
    x_2d = input.reshape(-1, in_features)

    # Build g_idx and shift tables
    g_idx = torch.arange(K, device=dev, dtype=torch.int32) // block_size  # [K]
    wf = torch.arange(PACK_FACTOR, device=dev, dtype=torch.int32) * 4  # [8]
    wf_unsqueeze_zero = wf.view(1, 1, PACK_FACTOR)  # [1,1,8]
    wf_unsqueeze_neg_one = wf.view(1, PACK_FACTOR, 1)  # [1,8,1]

    # Unpack zero points: shift each packed int32 by 0,4,...,28 bits, then mask to 4 bits
    zeros = torch.bitwise_right_shift(
        torch.unsqueeze(qzeros, 2).expand(-1, -1, PACK_FACTOR),
        wf_unsqueeze_zero,
    ).to(dequant_dtype)
    zeros = torch.bitwise_and(zeros, MAXQ).reshape(scales.shape)

    # Unpack weights the same way along the packed K dimension
    weight = torch.bitwise_and(
        torch.bitwise_right_shift(
            torch.unsqueeze(qweight, 1).expand(-1, PACK_FACTOR, -1),
            wf_unsqueeze_neg_one,
        ).to(dequant_dtype),
        MAXQ,
    )
    weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2])

    # Dequantize per group: scales[g] * (w - zero[g]), then GEMM in the input dtype
    weights = (scales[g_idx.long()] * (weight - zeros[g_idx.long()])).to(input.dtype)

    out = torch.matmul(x_2d, weights)

    if bias is not None:
        out = out + bias

    # Reshape output back to match input batch dimensions
    out = out.reshape(*input_shape[:-1], N)

    return out


@torch_fake_quant_int4_gptq_linear.register_fake
def torch_fake_quant_int4_gptq_linear_fake(
    input: torch.Tensor,
    weight_quantized: torch.Tensor,
    bias: Optional[torch.Tensor],
    input_scale: List[torch.Tensor],
    weight_scale: List[torch.Tensor],
    input_zp: List[torch.Tensor],
    weight_zp: List[torch.Tensor],
) -> torch.Tensor:
    N = weight_quantized.size(1)
    return torch.empty((*input.shape[:-1], N), dtype=input.dtype, device=input.device)
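
The op above assumes the standard GPTQ nibble packing: eight unsigned 4-bit values per int32, least-significant nibble first. A standalone sketch (not part of this commit) that checks the shift-and-mask unpack on a single packed word:

import torch

packed = torch.tensor([0x76543210], dtype=torch.int32)  # eight nibbles 0..7, LSB first
shifts = torch.arange(8, dtype=torch.int32) * 4          # same shift table as wf above
nibbles = torch.bitwise_and(torch.bitwise_right_shift(packed, shifts), 15)
print(nibbles.tolist())  # [0, 1, 2, 3, 4, 5, 6, 7]
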
@@ -146,6 +146,8 @@ class HFQuantConfigReader(QuantConfigReader):
    Quantization reader that processes transformers.quantizers.HFQuantizer functionality
    """

    _ALWAYS_EXCLUDE = ("lm_head", "model.embed_tokens")

    def __init__(self):
        super().__init__()
        self._hf_quantizer = None
@@ -156,6 +158,10 @@ class HFQuantConfigReader(QuantConfigReader):
        if not qconf:
            raise ValueError("HF quantization_config not found.")

        # Inject default exclusions; "model.embed_tokens" covers the "tie_word_embeddings": true case
        excludes = qconf.get("exclude_modules", [])
        qconf["exclude_modules"] = excludes + [n for n in self._ALWAYS_EXCLUDE if n not in excludes]

        self._quant_config = qconf
        from transformers.quantizers import AutoHfQuantizer
@@ -177,12 +183,23 @@ class HFQuantConfigReader(QuantConfigReader):
        if not isinstance(qconf, dict):
            return None

        # TODO(Fridah-nv): this class is only verified with GPT-OSS MXFP4 and INT4-GPTQ; other HF
        # quantizers should have a similar workflow and will be added to the pipeline
        quant_method = str(qconf.get("quant_method", "")).lower()
        if quant_method not in ["mxfp4", "gptq"]:
            return None

        # Validate GPTQ config: currently only INT4 with group_size=128 is supported
        if quant_method == "gptq":
            bits = qconf.get("bits")
            group_size = qconf.get("group_size")
            if bits != 4:
                raise ValueError(f"GPTQ quantization only supports bits=4, got bits={bits}")
            if group_size != 128:
                raise ValueError(
                    f"GPTQ quantization only supports group_size=128, got group_size={group_size}"
                )

        reader = cls()
        extra_model_kwargs = reader.read_config(raw)
        return reader, extra_model_kwargs
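
For reference, the relevant fields of an HF GPTQ quantization_config look roughly like the illustrative dict below; only quant_method, bits, and group_size are inspected here, and any other fields are assumptions that vary by exporter.

# Illustrative only; values beyond quant_method/bits/group_size are assumptions.
qconf = {
    "quant_method": "gptq",
    "bits": 4,
    "group_size": 128,
    "sym": True,
}
# Accepted by the validation above; bits != 4 or group_size != 128 would raise ValueError.
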
@@ -191,11 +208,11 @@ class HFQuantConfigReader(QuantConfigReader):
    # more features to be added
    def post_process_model(self, model, model_config):
        if self._hf_quantizer is None:
            return model
        dtype = getattr(model_config, "dtype", None)
        new_dtype = self._hf_quantizer.update_dtype(dtype)
        if new_dtype is None:
            return model
        model.to(new_dtype)
        return model
@@ -110,11 +110,13 @@ class Quantization(BaseTransform):
        shared_config: SharedConfig,
    ) -> Tuple[GraphModule, TransformInfo]:
        qcfg = factory.get_quant_config()
        if not qcfg or (
            qcfg.get("quant_algo", "").upper() != self.algo_name
            and qcfg.get("quant_method", "").upper() != self.algo_name
        ):
            return gm, TransformInfo(
                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
            )

        excluded = qcfg.get("exclude_modules", [])
        cnt = 0
        for n in gm.graph.nodes:
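
With the widened check, a transform now fires when either key names its algorithm. A minimal sketch with illustrative configs (not from the commit) of the two shapes it accepts:

algo_name = "GPTQ"
modelopt_style = {"quant_algo": "GPTQ"}           # ModelOpt-style config
hf_style = {"quant_method": "gptq", "bits": 4}    # HF-style config (match is case-insensitive)
for qcfg in (modelopt_style, hf_style):
    assert (
        qcfg.get("quant_algo", "").upper() == algo_name
        or qcfg.get("quant_method", "").upper() == algo_name
    )
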
@@ -635,3 +637,104 @@ class NVFP4QuantizationFromGraph(NVFP4LinearQuantizationFromConfig):
        return gm, TransformInfo(
            skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=True
        )


@TransformRegistry.register("quantize_int4_gptq_linear_from_config")
class INT4GPTQLinearQuantizationFromConfig(Quantization):
    """Config-based INT4 GPTQ quantization for GPTQ-quantized checkpoints.

    GPTQ uses:
    - qweight: [K/8, N] int32 (8 packed int4 values per int32)
    - qzeros: [G, N/8] int32 (packed zero points)
    - scales: [G, N] float (per-group scales)
    """

    algo_name = "GPTQ"

    @staticmethod
    def target_op():
        return torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear.default

    @staticmethod
    def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor:
        """Returns placeholder qweight tensor [K/8, N] int32."""
        N, K = original_weight.shape
        assert K % 8 == 0, "K must be divisible by 8 for GPTQ int4 packing."
        return torch.empty((K // 8, N), dtype=torch.int32, device=original_weight.device)

    @staticmethod
    def scale_names() -> List[str]:
        return ["scales", "qzeros"]

    @staticmethod
    def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]:
        """Returns placeholder tensors for GPTQ scales and qzeros."""
        N, K = original_weight_shape
        BLOCK = 128  # GPTQ group size
        assert K % BLOCK == 0, "K must be divisible by 128 for GPTQ block quant."
        assert N % 8 == 0, "N must be divisible by 8 for GPTQ qzeros packing."
        G = K // BLOCK
        return {
            "scales": torch.empty((G, N), dtype=torch.float16),
            "qzeros": torch.empty((G, N // 8), dtype=torch.int32),
        }
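
As a concrete shape walk-through for a hypothetical layer (the same sizes as the unit test further below, not prescribed by this class):

N, K, BLOCK = 896, 4864, 128       # hypothetical linear layer, GPTQ group size 128
G = K // BLOCK                     # 38 groups
qweight_shape = (K // 8, N)        # (608, 896) int32, 8 int4 values per int32
scales_shape = (G, N)              # (38, 896) float16
qzeros_shape = (G, N // 8)         # (38, 112) int32, 8 zero-point nibbles per int32
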
    @staticmethod
    def build_custom_args_for_linear(scales: Dict[str, Node]) -> Tuple[object, ...]:
        """Build args for torch_fake_quant_int4_gptq_linear:
        (input, weight, bias, input_scale, weight_scale, input_zp, weight_zp)
        -> input_scale=[], weight_scale=[scales], input_zp=[], weight_zp=[qzeros]
        """
        return ([], [scales["scales"]], [], [scales["qzeros"]])

    @staticmethod
    def load_hook(state_dict, prefix, *args, weight_name: str):
        """
        Load hook for GPTQ checkpoints:
        - qweight: keep as [K/8, N] int32
        - scales: [G, N] float16
        - qzeros: [G, N/8] int32

        GPTQ checkpoints use the naming convention:
        - {prefix}qweight
        - {prefix}scales
        - {prefix}qzeros
        """
        mod_prefix, _, _ = weight_name.rpartition(".")

        qweight_ckpt = f"{mod_prefix}.qweight"
        scales_ckpt = f"{mod_prefix}.scales"
        qzeros_ckpt = f"{mod_prefix}.qzeros"

        if qweight_ckpt not in state_dict:
            return

        qweight = state_dict[qweight_ckpt]
        if qweight.dtype != torch.int32:
            return

        K_packed, N = qweight.shape  # [K/8, N]
        K = K_packed * 8

        assert scales_ckpt in state_dict, f"Missing {scales_ckpt}"
        scales = state_dict[scales_ckpt]  # [G, N]
        G = scales.shape[0]

        assert qzeros_ckpt in state_dict, f"Missing {qzeros_ckpt}"
        qzeros = state_dict[qzeros_ckpt]  # [G, N/8]

        # Validate GPTQ weight layout
        assert K % G == 0, f"K ({K}) must be divisible by G ({G})"
        assert scales.shape == (G, N), f"scales shape {scales.shape} != {(G, N)}"
        assert qzeros.shape == (G, N // 8), f"qzeros shape {qzeros.shape} != {(G, N // 8)}"

        # Map to our buffer names
        state_dict[weight_name] = qweight  # [K/8, N] int32
        state_dict[f"{mod_prefix}.scales"] = scales.to(torch.float16)  # [G, N]
        # GPTQ v1 format stores (zero_point - 1); convert to v2 by adding 0x11111111
        # See: gptqmodel.utils.model.convert_gptq_v1_to_v2_format_module
        qzeros_v2 = qzeros + 0x11111111
        state_dict[f"{mod_prefix}.qzeros"] = qzeros_v2  # [G, N/8] int32 (v2 format)
        # Remove the original qweight key to avoid "unexpected key" warnings
        del state_dict[qweight_ckpt]
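
A standalone illustration (not part of the commit) of why adding 0x11111111 converts v1 qzeros to v2: it adds 1 to each of the eight packed nibbles, so stored (zero_point - 1) values become zero_point again.

import torch

qzeros_v1 = torch.tensor([0x77777777], dtype=torch.int32)  # v1: nibble value 7 everywhere
qzeros_v2 = qzeros_v1 + 0x11111111                          # bit pattern 0x88888888 (wraps in int32)
shifts = torch.arange(8, dtype=torch.int32) * 4
nibbles = torch.bitwise_and(torch.bitwise_right_shift(qzeros_v2, shifts), 15)
print(nibbles.tolist())  # [8, 8, 8, 8, 8, 8, 8, 8]
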
@@ -0,0 +1,178 @@
import pytest
import torch

import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401


def pack_gptq_qweight_from_u4(U4_nk: torch.Tensor) -> torch.Tensor:
    """
    GPTQ: pack along K, 8 nibbles per int32.
    U4_nk: [N,K] uint8 in [0..15]
    -> qweight: [K/8, N] int32
    """
    assert U4_nk.dtype == torch.uint8 and U4_nk.dim() == 2
    N, K = U4_nk.shape
    assert K % 8 == 0
    shifts = torch.arange(8, dtype=torch.int32).view(1, 8, 1) * 4  # [1,8,1]
    U4_kn = U4_nk.T.contiguous()  # [K, N] u8
    U4_blocks = U4_kn.view(K // 8, 8, N).to(torch.int32)  # [K/8,8,N]
    qweight = torch.sum(U4_blocks << shifts, dim=1)  # [K/8, N]
    return qweight


def pack_qzeros_all_8(G: int, N: int) -> torch.Tensor:
    """
    Build qzeros: [G, N/8] int32 such that each unpacked nibble == 8.
    Each int32 holds 8 nibbles; the signed int32 value -0x77777778 has the same
    bit pattern as 0x88888888 (unsigned).
    """
    assert N % 8 == 0
    val = torch.tensor(-0x77777778, dtype=torch.int32)  # == -2004318072
    return val.repeat(G, N // 8)  # [G, N/8]
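
A quick standalone check of that bit-pattern claim:

# 0x88888888 exceeds the signed int32 range, so the helper uses its two's-complement
# twin -0x77777778; the low 32 bits are identical.
assert (-0x77777778) & 0xFFFFFFFF == 0x88888888
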
def pack_uint8_from_Qs_signed(Qs_nk: torch.Tensor) -> torch.Tensor:
    """
    ModelOpt: pack along N, 2 nibbles per byte from signed int4 Qs in [-8..7].
    Qs_nk: [N,K] int8
    -> packed: [N/2, K] uint8 (low nibble = even row, high nibble = odd row)
    """
    assert Qs_nk.dtype == torch.int8 and Qs_nk.dim() == 2
    N, K = Qs_nk.shape
    assert N % 2 == 0

    # map signed -> nibble (two's complement)
    def to_u4(x: torch.Tensor) -> torch.Tensor:
        x16 = x.to(torch.int16)
        u = torch.where(x16 >= 0, x16, x16 + 16).to(torch.uint8)  # [N,K] in 0..15
        return u

    even_u4 = to_u4(Qs_nk[0::2, :])  # [N/2, K] u8
    odd_u4 = to_u4(Qs_nk[1::2, :])  # [N/2, K] u8
    return (even_u4 | (odd_u4 << 4)).contiguous().to(torch.uint8)  # [N/2, K]


def gptq_unpack_unsigned_u4_KN(
    qweight: torch.Tensor, wf_unsqueeze_neg_one: torch.Tensor
) -> torch.Tensor:
    """
    Mirror the custom op's unpack (for the weight path): returns unsigned nibbles [K,N] u8.
    """
    pack_factor = 8
    w = torch.bitwise_right_shift(
        qweight.unsqueeze(1).expand(-1, pack_factor, -1),  # [K/8,8,N]
        wf_unsqueeze_neg_one.to(qweight.dtype),  # [1,8,1]
    ).to(torch.int16)
    w = (w & 15).to(torch.uint8).reshape(-1, qweight.shape[1])  # [K,N] u8
    return w


def modelopt_unpack_Qs_signed_NK(weight_packed: torch.Tensor) -> torch.Tensor:
    """
    Unpack ModelOpt packed bytes back to signed int4 in [-8..7], [N,K] int8.
    """
    pw = weight_packed.T.contiguous()  # [K, N/2] u8
    low = (pw & 0x0F).to(torch.int16)  # [K, N/2]
    high = ((pw >> 4) & 0x0F).to(torch.int16)  # [K, N/2]
    low_s = torch.where(low >= 8, low - 16, low).to(torch.int8)  # [-8..7]
    high_s = torch.where(high >= 8, high - 16, high).to(torch.int8)  # [-8..7]
    rebuilt = torch.stack([low_s, high_s], dim=-1)  # [K, N/2, 2] i8
    int8_T = rebuilt.reshape(pw.shape[0], -1)  # [K, N] i8
    return int8_T.T.contiguous()  # [N, K] i8


def gptq_weight_and_out(input_x, qweight, qzeros, scales_gn, g_idx, wf_zero, wf_neg1):
    """
    The exact math the custom op implements:
        weights = scales[g_idx] * (unpacked_u4 - unpacked_qzeros[g_idx]).to(input.dtype)  # [K,N]
        out = input @ weights
    """
    # unpack unsigned nibbles [K,N]
    u4_kn = gptq_unpack_unsigned_u4_KN(qweight, wf_neg1).to(torch.int16)  # [K,N]
    # unpack qzeros to nibbles [G,N] -> broadcast with g_idx
    zeros = (
        torch.bitwise_right_shift(
            qzeros.unsqueeze(2).expand(-1, -1, 8),  # [G, N/8, 8]
            wf_zero.to(qzeros.dtype),  # [1,1,8]
        ).to(torch.int16)
        & 15
    )
    zeros_gn = zeros.reshape(scales_gn.shape)  # [G,N]
    z_kn = zeros_gn[g_idx.long()]  # [K,N] int16

    scale_kn = scales_gn[g_idx.long()].to(torch.float32)  # [K,N]
    W_kn = scale_kn * (u4_kn - z_kn).to(torch.float32)  # [K,N]
    y = input_x.to(torch.float32) @ W_kn
    return W_kn, y


def modelopt_weight_and_out(input_x, weight_packed, weight_scale_ng):
    Qs_nk = modelopt_unpack_Qs_signed_NK(weight_packed).to(torch.float32)  # [N,K]
    S_nk = weight_scale_ng.repeat_interleave(128, dim=1).to(torch.float32)  # [N,K]
    W_kn = (Qs_nk * S_nk).T.contiguous()  # [K,N]
    y = input_x.to(torch.float32) @ W_kn
    return W_kn, y


@pytest.mark.parametrize("N,K,BLOCK_SIZE", [(896, 4864, 128)])
def test_gptq_vs_modelopt_qzeros_8_match(N, K, BLOCK_SIZE):
    torch.manual_seed(0)
    G = K // BLOCK_SIZE
    assert K % 8 == 0 and K % BLOCK_SIZE == 0 and N % 2 == 0

    # Ground-truth signed int4 weights Q_s in [-8..7]
    Qs_nk = torch.randint(-8, 8, (N, K), dtype=torch.int8)

    # Convert to codebooks for each path
    U4_gptq = (Qs_nk.to(torch.int16) + 8).to(torch.uint8)  # [N,K] 0..15
    weight_quantized = pack_uint8_from_Qs_signed(Qs_nk)  # [N/2, K] u8

    # Pack GPTQ qweight and qzeros (all nibbles = 8)
    qweight_gptq = pack_gptq_qweight_from_u4(U4_gptq)  # [K/8, N]
    qzeros = pack_qzeros_all_8(G, N)  # [G, N/8] i32

    # Scales: GPTQ stores [G,N], ModelOpt stores [N,G] (transpose)
    scales_gn = torch.rand(G, N, dtype=torch.float32) * 2.0  # [G,N]
    weight_scale_ng = scales_gn.T.contiguous()  # [N,G]

    # Index & shifts
    g_idx = torch.arange(K, dtype=torch.int32) // BLOCK_SIZE  # [K]
    wf = torch.arange(8, dtype=torch.int32) * 4
    wf_zero = wf.view(1, 1, 8)  # [1,1,8]
    wf_neg1 = wf.view(1, 8, 1)  # [1,8,1]

    x = torch.randn(3, K, dtype=torch.float32)

    Wg, yg = gptq_weight_and_out(x, qweight_gptq, qzeros, scales_gn, g_idx, wf_zero, wf_neg1)
    Wm, ym = modelopt_weight_and_out(x, weight_quantized, weight_scale_ng)

    torch.testing.assert_close(Wg, Wm, rtol=0, atol=0)
    torch.testing.assert_close(yg, ym, rtol=0, atol=0)

    bias = None
    pre_scale = torch.tensor(1.0, dtype=torch.float32)
    input_scale_list = [pre_scale]
    weight_scale_list = [weight_scale_ng]
    input_zp_list, weight_zp_list = [torch.tensor(0)], [torch.tensor(0)]

    y_gptq = torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear(
        x,
        qweight_gptq,
        None,  # bias
        [],  # input_scale
        [scales_gn],  # weight_scale
        [],  # input_zp
        [qzeros],  # weight_zp
    )
    y_mo = torch.ops.auto_deploy.torch_fake_quant_int4_linear(
        x,
        weight_quantized,
        bias,
        input_scale_list,
        weight_scale_list,
        input_zp_list,
        weight_zp_list,
    )

    # small mismatch ≈ 5/2048, likely from the GEMM computation
    torch.testing.assert_close(y_gptq, y_mo, rtol=0, atol=3e-3)
@@ -251,7 +251,119 @@ def test_int4awq_transform_graph_and_load_hook():
        True,  # test_load_hook
        False,  # strict_loading
        None,  # dynamic_shapes
        None,  # check_num_matches
        False,  # skip_output_assert
        quant_config,
    )

    # Still exportable
    torch_export_to_gm(gm_transformed, args=(x,))


def _pack_gptq_qweight(weights: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """
    Pack float weights to GPTQ qweight format [K/8, N] int32.
    Uses GPTQ symmetric quantization: signed int4 [-8,7] stored as unsigned [0,15].

    weights: [N, K] float
    scales: [G, N] float (per-group scales)
    Returns: [K/8, N] int32 packed qweight
    """
    N, K = weights.shape
    G = scales.shape[0]
    BLOCK = K // G
    assert K % 8 == 0

    # Expand scales from [G, N] to [N, K] for per-element division
    scales_nk = scales.T.repeat_interleave(BLOCK, dim=1)  # [N, K]

    # Quantize: w / scale gives float in approx [-8, 7] range, round to int
    # GPTQ uses zero_point=8, so signed int4 [-8,7] maps to unsigned [0,15]
    q_signed = (weights / scales_nk).round().clamp(-8, 7).to(torch.int8)  # [N, K] in [-8,7]
    u4_nk = (q_signed.to(torch.int16) + 8).to(torch.uint8)  # [N, K] in [0,15]

    # Pack along K: [N, K] -> [K/8, N] int32
    # Use explicit int32 ops to avoid promotion to int64
    u4_kn = u4_nk.T.contiguous()  # [K, N]
    u4_blocks = u4_kn.view(K // 8, 8, N).to(torch.int32)  # [K/8, 8, N]
    shifts = [0, 4, 8, 12, 16, 20, 24, 28]
    qweight = torch.zeros(K // 8, N, dtype=torch.int32, device=weights.device)
    for i in range(8):
        qweight = qweight | (u4_blocks[:, i, :] << shifts[i])
    return qweight


def _pack_gptq_qzeros_v1(G: int, N: int, device: torch.device) -> torch.Tensor:
    """
    Build qzeros [G, N/8] int32 in GPTQ v1 format.
    v1 stores (zero_point - 1) = 7 per nibble -> 0x77777777 per int32.
    """
    assert N % 8 == 0
    val = torch.tensor(0x77777777, dtype=torch.int32, device=device)
    return val.repeat(G, N // 8)


def test_int4gptq_transform_graph_and_load_hook():
    """INT4 GPTQ transform with load_hook converting v1 qzeros to v2."""
    device = "cuda"
    torch.manual_seed(42)
    quant_config = {"quant_algo": "GPTQ"}
    BLOCK = 128  # GPTQ group size
    QUANT_OP = torch.ops.auto_deploy.torch_fake_quant_int4_gptq_linear

    # Model with K divisible by 128, N divisible by 8
    model = MLP(256, 128, 256).to(torch.float16).to(device)
    x = torch.randn(3, 256, dtype=torch.float16, device=device)

    def gptq_state_dict_hook(module: nn.Module, state_dict: dict, prefix: str, local_meta: dict):
        """Convert FP weights to GPTQ checkpoint format (v1 qzeros)."""
        for name, m in module.named_modules():
            if not isinstance(m, nn.Linear):
                continue
            key_w = f"{prefix}{name}.weight"
            if key_w not in state_dict:
                continue

            W = state_dict[key_w].detach().to(torch.float32).to(device)  # [N, K]
            N, K = W.shape
            assert N % 8 == 0 and K % BLOCK == 0

            G = K // BLOCK

            # Per-group scales from per-group amax
            amax = W.abs().view(N, G, BLOCK).amax(dim=-1)  # [N, G]
            scales = (amax / 7.0).T.to(torch.float32)  # [G, N] float32 for quantization

            # Build GPTQ tensors using symmetric quantization with zero_point=8
            qweight = _pack_gptq_qweight(W, scales)  # [K/8, N] int32
            qzeros = _pack_gptq_qzeros_v1(G, N, W.device)  # [G, N/8] int32 (v1, zero_point=8)
            scales = scales.to(torch.float16)  # [G, N] convert to fp16 for storage

            # Replace weight with qweight, add GPTQ-specific keys
            state_dict[f"{prefix}{name}.qweight"] = qweight
            state_dict[f"{prefix}{name}.scales"] = scales
            state_dict[f"{prefix}{name}.qzeros"] = qzeros
            # Remove original weight key
            del state_dict[key_w]

    model._register_state_dict_hook(gptq_state_dict_hook)

    gm = torch_export_to_gm(model, args=(x,), clone=True)
    gm_transformed = InferenceOptimizer(
        DummyFactory(quant_config),
        {"quantize_int4_gptq_linear_from_config": {"stage": "pattern_matcher"}},
    )(None, gm).to(device)

    run_test_transformed_gm(
        model,
        x,
        gm_transformed,
        lambda gm_: any(is_op(n, QUANT_OP) for n in gm_.graph.nodes),
        lambda num_p_og: num_p_og // 8,  # qweight packs 8 int4 values per int32
        0.5,  # atol (quantization error expected)
        0.5,  # rtol
        True,  # test_load_hook
        False,  # strict_loading
        None,  # dynamic_shapes
        False,  # skip_output_assert
        quant_config,
    )