[quantization] feat: support aobaseconfig classes in TorchAOConfig (#12275)
* feat: support aobaseconfig classes. * [docs] AOBaseConfig (#12302) init Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * up * replace with is_torchao_version * up * up --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
@@ -14,11 +14,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import importlib.metadata
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from packaging import version
|
||||
from parameterized import parameterized
|
||||
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
||||
|
||||
@@ -65,6 +67,9 @@ if is_torchao_available():
|
||||
from torchao.quantization.quant_primitives import MappingType
|
||||
from torchao.utils import get_model_size_in_bytes
|
||||
|
||||
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.9.0"):
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_torch_accelerator
|
||||
@@ -522,6 +527,15 @@ class TorchAoTest(unittest.TestCase):
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
_ = pipe(**inputs)
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.9.0")
|
||||
def test_aobase_config(self):
|
||||
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
|
||||
components = self.get_dummy_components(quantization_config)
|
||||
pipe = FluxPipeline(**components).to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
_ = pipe(**inputs)
|
||||
|
||||
|
||||
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
|
||||
@require_torch
|
||||
@@ -628,6 +642,14 @@ class TorchAoSerializationTest(unittest.TestCase):
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.9.0")
|
||||
def test_aobase_config(self):
|
||||
quant_method, quant_method_kwargs = Int8WeightOnlyConfig(), {}
|
||||
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551])
|
||||
device = torch_device
|
||||
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice)
|
||||
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
|
||||
|
||||
|
||||
@require_torchao_version_greater_or_equal("0.7.0")
|
||||
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user