[Quantization] Quanto quantizer (#29023)

* start integration

* fix

* add and debug tests

* update tests

* make pytorch serialization work

* compatible with device_map and offload

* fix tests

* make style

* add ref

* guard against safetensors

* add float8 and style

* fix is_serializable

* Fix shard_checkpoint compatibility with quanto

* more tests

* docs

* adjust memory

* better

* style

* pass tests

* Update src/transformers/modeling_utils.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* add is_safe_serialization instead

* Update src/transformers/quantizers/quantizer_quanto.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* add QbitsTensor tests

* fix tests

* simplify activation list

* Update docs/source/en/quantization.md

Co-authored-by: David Corvoysier <david.corvoysier@gmail.com>

* better comment

* Update tests/quantization/quanto_integration/test_quanto.py

Co-authored-by: David Corvoysier <david.corvoysier@gmail.com>

* Update tests/quantization/quanto_integration/test_quanto.py

Co-authored-by: David Corvoysier <david.corvoysier@gmail.com>

* find and fix edge case

* Update docs/source/en/quantization.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* pass weights_only_kwarg instead

* fix shard_checkpoint loading

* simplify update_missing_keys

* Update tests/quantization/quanto_integration/test_quanto.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* recursion to get all tensors

* block serialization

* skip serialization tests

* fix

* change by cuda:0 for now

* fix regression

* update device_map

* fix doc

* add notebook

* update torch_dtype

* update doc

* typo

* typo

* remove comm

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: David Corvoysier <david.corvoysier@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <younesbelkada@gmail.com>
Marc Sun 2024-03-15 11:51:29 -04:00 committed by GitHub
parent f02aea2737
commit 28de2f4de3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 885 additions and 7 deletions

View File

@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl
# Add quanto for quantization testing
RUN python3 -m pip install --no-cache-dir quanto
# When installing in editable mode, `transformers` is not recognized as a package.
# this line must be added in order for python to be aware of transformers.
RUN cd transformers && python3 setup.py develop

View File

@ -26,6 +26,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
</Tip>
## QuantoConfig
[[autodoc]] QuantoConfig
## AqlmConfig
[[autodoc]] AqlmConfig

View File

@ -26,6 +26,59 @@ Interested in adding a new quantization method to Transformers? Read the [HfQuan
</Tip>
## Quanto
<Tip>
Try Quanto + transformers with this [notebook](https://colab.research.google.com/drive/16CXfVmtdQvciSh9BopZUDYcmXCDpvgrT?usp=sharing)!
</Tip>
The [🤗 Quanto](https://github.com/huggingface/quanto) library is a versatile PyTorch quantization toolkit. It relies on linear quantization and provides several unique features such as:
- weight quantization (`float8`, `int8`, `int4`, `int2`)
- activation quantization (`float8`, `int8`)
- modality agnostic (e.g. CV, LLM)
- device agnostic (e.g. CUDA, MPS, CPU)
- compatibility with `torch.compile`
- easy to add custom kernels for a specific device
- support for quantization-aware training
<!-- Add link to the blogpost -->
Before you begin, make sure the following libraries are installed:
```bash
pip install quanto
pip install git+https://github.com/huggingface/accelerate.git
pip install git+https://github.com/huggingface/transformers.git
```
Now you can quantize a model by passing a [`QuantoConfig`] object to the [`~PreTrainedModel.from_pretrained`] method. This works for any model in any modality, as long as it contains `torch.nn.Linear` layers.
The transformers integration only supports weight quantization. For more complex use cases such as activation quantization, calibration and quantization-aware training, you should use the [quanto](https://github.com/huggingface/quanto) library directly.
```py
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantization_config = QuantoConfig(weights="int8")
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", quantization_config=quantization_config)
```
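Once loaded, the quantized model can be used for generation like any other transformers model. As a quick sanity check (illustrative only, reusing the model and tokenizer from the snippet above with an arbitrary prompt):
```py
inputs = tokenizer("Hello my name is", return_tensors="pt").to(quantized_model.device)
# Generate a short continuation with the quantized weights
outputs = quantized_model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```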
Note that serialization is not supported yet in transformers, but it is coming soon! In the meantime, if you want to save the model, you can use the quanto library directly.
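For instance, a minimal sketch of quantizing and saving a model with quanto directly could look like this (assuming quanto's `quantize`/`freeze` API as used in the tests of this PR; the output file name is arbitrary):
```py
import torch
from quanto import freeze, qint8, quantize
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype=torch.float32)
# Quantize the weights to int8 in place, then freeze them so they are stored in their quantized form
quantize(model, weights=qint8)
freeze(model)
# Save the quantized state dict with plain PyTorch serialization
torch.save(model.state_dict(), "opt-125m-quanto-int8.pt")
```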
Quanto uses a linear quantization algorithm. Even though this is a basic quantization technique, it gives very good results! Have a look at the following benchmark (llama-2-7b, perplexity metric). You can find more benchmarks [here](https://github.com/huggingface/quanto/tree/main/bench/generation).
<div class="flex gap-4">
<div>
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/NousResearch-Llama-2-7b-hf_Perplexity.png" alt="llama-2-7b-quanto-perplexity" />
</div>
</div>
The library is versatile enough to be compatible with most PTQ optimization algorithms. The plan is to integrate the most popular ones (AWQ, SmoothQuant) as seamlessly as possible.
## AQLM

View File

@ -1100,7 +1100,7 @@ _import_structure = {
"is_vision_available",
"logging",
],
"utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
"utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig", "QuantoConfig"],
}
# sentencepiece-backed objects
@ -5921,7 +5921,7 @@ if TYPE_CHECKING:
)
# bitsandbytes config
from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig
from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig, QuantoConfig
try:
if not is_sentencepiece_available():

View File

@ -82,6 +82,7 @@ _import_structure = {
"run_hp_search_wandb",
],
"peft": ["PeftAdapterMixin"],
"quanto": ["replace_with_quanto_layers"],
}
if TYPE_CHECKING:
@ -150,6 +151,7 @@ if TYPE_CHECKING:
run_hp_search_wandb,
)
from .peft import PeftAdapterMixin
from .quanto import replace_with_quanto_layers
else:
import sys

View File

@ -0,0 +1,94 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import is_torch_available
if is_torch_available():
import torch
def replace_with_quanto_layers(
model,
quantization_config=None,
modules_to_not_convert=None,
current_key_name=None,
has_been_replaced=False,
):
"""
Public method that recursively replaces the Linear layers of the given model with Quanto quantized layers.
Returns the converted model and a boolean that indicates if the conversion has been successful or not.
Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
quantization_config (`QuantoConfig`, defaults to `None`):
The quantization config object that contains the quantization parameters.
modules_to_not_convert (`list`, *optional*, defaults to `None`):
A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
converted.
current_key_name (`list`, *optional*, defaults to `None`):
A list that contains the current key name. This is used for recursion and should not be passed by the user.
has_been_replaced (`bool`, *optional*, defaults to `False`):
A boolean that indicates if the conversion has been successful or not. This is used for recursion and
should not be passed by the user.
"""
from accelerate import init_empty_weights
from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8
w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
a_mapping = {None: None, "float8": qfloat8, "int8": qint8}
if modules_to_not_convert is None:
modules_to_not_convert = []
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
with init_empty_weights():
if isinstance(module, torch.nn.Linear):
model._modules[name] = QLinear(
in_features=module.in_features,
out_features=module.out_features,
bias=module.bias is not None,
dtype=module.weight.dtype,
weights=w_mapping[quantization_config.weights],
activations=a_mapping[quantization_config.activations],
)
model._modules[name].requires_grad_(False)
has_been_replaced = True
elif isinstance(module, torch.nn.LayerNorm):
if quantization_config.activations is not None:
model._modules[name] = QLayerNorm(
module.normalized_shape,
module.eps,
module.elementwise_affine,
module.bias is not None,
activations=a_mapping[quantization_config.activations],
)
has_been_replaced = True
if len(list(module.children())) > 0:
_, has_been_replaced = replace_with_quanto_layers(
module,
quantization_config=quantization_config,
modules_to_not_convert=modules_to_not_convert,
current_key_name=current_key_name,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
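A minimal usage sketch for this helper, along the lines of what the integration tests do (assuming an int8 weight-only `QuantoConfig`; the model id is just an example):
```py
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM, QuantoConfig
from transformers.integrations import replace_with_quanto_layers

config = AutoConfig.from_pretrained("facebook/opt-125m")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)
# Swap every nn.Linear except the lm_head for a quanto QLinear with int8 weights
model, has_been_replaced = replace_with_quanto_layers(
    model, quantization_config=QuantoConfig(weights="int8"), modules_to_not_convert=["lm_head"]
)
```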

View File

@ -802,7 +802,11 @@ def _load_state_dict_into_meta_model(
elif (
not is_quantized
or (not hf_quantizer.requires_parameters_quantization)
or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict))
or (
not hf_quantizer.check_quantized_param(
model, param, param_name, state_dict, param_device=param_device, device_map=device_map
)
)
):
# For backward compatibility with older versions of `accelerate` and for non-quantized params
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
@ -3728,6 +3732,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
for pat in cls._keys_to_ignore_on_load_unexpected:
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
if hf_quantizer is not None:
missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
# retrieve weights on meta device and put them back on CPU.
# This is not ideal in terms of memory, but if we don't do it now, we can't initialize them in the next step
if low_cpu_mem_usage:

View File

@ -22,12 +22,14 @@ from ..utils.quantization_config import (
GPTQConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
AUTO_QUANTIZER_MAPPING = {
@ -36,6 +38,7 @@ AUTO_QUANTIZER_MAPPING = {
"bitsandbytes_8bit": Bnb8BitHfQuantizer,
"gptq": GptqHfQuantizer,
"aqlm": AqlmHfQuantizer,
"quanto": QuantoHfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@ -44,6 +47,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"bitsandbytes_8bit": BitsAndBytesConfig,
"gptq": GPTQConfig,
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
}

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from ..utils import is_torch_available
from ..utils.quantization_config import QuantizationConfigMixin
@ -99,6 +99,16 @@ class HfQuantizer(ABC):
"""
return torch_dtype
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
"""
Override this method if you want to adjust the `missing_keys`.
Args:
missing_keys (`List[str]`, *optional*):
The list of missing keys in the checkpoint compared to the state dict of the model
"""
return missing_keys
def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
"""
returns dtypes for modules that are not quantized - used for the computation of the device_map in case
@ -111,6 +121,7 @@ class HfQuantizer(ABC):
torch_dtype (`torch.dtype`):
The dtype passed in `from_pretrained` method.
"""
return {
name: torch_dtype
for name, _ in model.named_parameters()
@ -122,7 +133,12 @@ class HfQuantizer(ABC):
return max_memory
def check_quantized_param(
self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
"""
checks if a loaded state_dict component is part of quantized param + some validation; only defined if

View File

@ -116,7 +116,12 @@ class Bnb4BitHfQuantizer(HfQuantizer):
)
def check_quantized_param(
self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
import bitsandbytes as bnb

View File

@ -134,7 +134,12 @@ class Bnb8BitHfQuantizer(HfQuantizer):
return torch.int8
def check_quantized_param(
self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
):
import bitsandbytes as bnb

View File

@ -0,0 +1,198 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from packaging import version
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..utils import is_accelerate_available, is_quanto_available, is_torch_available, logging
from ..utils.quantization_config import QuantoConfig
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class QuantoHfQuantizer(HfQuantizer):
"""
Quantizer for the quanto library
"""
required_packages = ["quanto", "accelerate"]
requires_parameters_quantization = True
requires_calibration = False
def __init__(self, quantization_config: QuantoConfig, **kwargs):
super().__init__(quantization_config, **kwargs)
self.post_init()
def post_init(self):
r"""
Safety checker
"""
if self.quantization_config.activations is not None and not self.pre_quantized:
raise ValueError(
"We don't support quantizing the activations with transformers library."
"Use quanto library for more complex use cases such as activations quantization, calibration and quantization aware training."
)
def validate_environment(self, *args, **kwargs):
if not is_quanto_available():
raise ImportError("Loading a quanto quantized model requires quanto library (`pip install quanto`)")
if not is_accelerate_available():
raise ImportError("Loading a quanto quantized model requires accelerate library (`pip install quanto`)")
def update_device_map(self, device_map):
if device_map is None:
device_map = {"": "cpu"}
logger.info(
"The device_map was not initialized. "
"Setting device_map to {'':'cpu'}. "
"If you want to use the model for inference, please set device_map ='auto'"
)
return device_map
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
torch_dtype = torch.float32
return torch_dtype
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
import quanto
not_missing_keys = []
for name, module in model.named_modules():
if isinstance(module, quanto.QModuleMixin):
for missing in missing_keys:
if (
(name in missing or name in f"{prefix}.{missing}")
and not missing.endswith(".weight")
and not missing.endswith(".bias")
):
not_missing_keys.append(missing)
return [k for k in missing_keys if k not in not_missing_keys]
def check_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
"""
Check if a parameter needs to be quantized.
"""
import quanto
device_map = kwargs.get("device_map", None)
param_device = kwargs.get("param_device", None)
# we don't quantize a parameter if its module is going to be offloaded to the cpu
if device_map is not None and param_device is not None:
device_map_values = set(device_map.values())
if param_device == "cpu" and len(device_map_values) > 1:
if not (device_map_values == {"cpu"} or device_map_values == {"cpu", "disk"}):
return False
module, tensor_name = get_module_from_name(model, param_name)
# We only quantize the weights and the bias is not quantized.
if isinstance(module, quanto.QModuleMixin) and "weight" in tensor_name:
# if the weights are quantized, don't need to recreate it again with `create_quantized_param`
return not module.frozen
else:
return False
def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
max_memory = {key: val * 0.90 for key, val in max_memory.items()}
return max_memory
def create_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
*args,
**kwargs,
):
"""
Create the quantized parameter by calling .freeze() after setting it to the module.
"""
from accelerate.utils import set_module_tensor_to_device
set_module_tensor_to_device(model, param_name, target_device, param_value)
module, _ = get_module_from_name(model, param_name)
module.freeze()
module.weight.requires_grad = False
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.27.0"):
from accelerate.utils import CustomDtype
mapping = {
"int8": torch.int8,
"float8": CustomDtype.FP8,
"int4": CustomDtype.INT4,
"int2": CustomDtype.INT2,
}
target_dtype = mapping[self.quantization_config.weights]
return target_dtype
else:
raise ValueError(
"You are using `device_map='auto'` on a quanto quantized model. To automatically compute"
" the appropriate device map, you should upgrade your `accelerate` library,"
"`pip install --upgrade accelerate` or install it from source."
)
def _process_model_before_weight_loading(
self, model: "PreTrainedModel", keep_in_fp32_modules: List[str] = [], **kwargs
):
from ..integrations import get_keys_to_not_convert, replace_with_quanto_layers
# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
if self.quantization_config.modules_to_not_convert is None:
self.modules_to_not_convert = get_keys_to_not_convert(model)
else:
self.modules_to_not_convert = self.quantization_config.modules_to_not_convert
if not isinstance(self.modules_to_not_convert, list):
self.modules_to_not_convert = [self.modules_to_not_convert]
self.modules_to_not_convert.extend(keep_in_fp32_modules)
model, _ = replace_with_quanto_layers(
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
)
model.config.quantization_config = self.quantization_config
def _process_model_after_weight_loading(self, model):
return model
@property
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
return False
@property
def is_serializable(self):
return False

View File

@ -89,6 +89,7 @@ from .utils import (
is_pytesseract_available,
is_pytest_available,
is_pytorch_quantization_available,
is_quanto_available,
is_rjieba_available,
is_sacremoses_available,
is_safetensors_available,
@ -1043,6 +1044,13 @@ def require_auto_awq(test_case):
return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)
def require_quanto(test_case):
"""
Decorator marking a test that requires quanto
"""
return unittest.skipUnless(is_quanto_available(), "test requires quanto")(test_case)
def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer

View File

@ -152,6 +152,7 @@ from .import_utils import (
is_pytesseract_available,
is_pytest_available,
is_pytorch_quantization_available,
is_quanto_available,
is_rjieba_available,
is_sacremoses_available,
is_safetensors_available,

View File

@ -131,6 +131,7 @@ _optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quanto_available = _is_package_available("quanto")
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
@ -788,6 +789,10 @@ def is_auto_awq_available():
return _auto_awq_available
def is_quanto_available():
return _quanto_available
def is_auto_gptq_available():
return _auto_gptq_available

View File

@ -39,6 +39,7 @@ class QuantizationMethod(str, Enum):
GPTQ = "gptq"
AWQ = "awq"
AQLM = "aqlm"
QUANTO = "quanto"
class AWQLinearVersion(str, Enum):
@ -836,3 +837,44 @@ class AqlmConfig(QuantizationConfigMixin):
if self.linear_weights_not_to_quantize is None:
self.linear_weights_not_to_quantize = []
@dataclass
class QuantoConfig(QuantizationConfigMixin):
"""
This is a wrapper class for all the attributes and features that you can play with when loading a model using `quanto`.
Args:
weights (`str`, *optional*, defaults to `"int8"`):
The target dtype for the weights after quantization. Supported values are ("float8", "int8", "int4", "int2").
activations (`str`, *optional*):
The target dtype for the activations after quantization. Supported values are (None, "int8", "float8").
modules_to_not_convert (`list`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have
some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
"""
def __init__(
self,
weights="int8",
activations=None,
modules_to_not_convert: Optional[List] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.QUANTO
self.weights = weights
self.activations = activations
self.modules_to_not_convert = modules_to_not_convert
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct
"""
accepted_weights = ["float8", "int8", "int4", "int2"]
accepted_activations = [None, "int8", "float8"]
if self.weights not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
if self.activations not in accepted_activations:
raise ValueError(f"Only support weights in {accepted_activations} but found {self.activations}")

View File

@ -0,0 +1,431 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig
from transformers.testing_utils import require_accelerate, require_quanto, require_torch_gpu, slow
from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate import init_empty_weights
if is_quanto_available():
from quanto import QLayerNorm, QLinear
from transformers.integrations.quanto import replace_with_quanto_layers
class QuantoConfigTest(unittest.TestCase):
def test_attributes(self):
pass
@require_quanto
@require_accelerate
class QuantoTestIntegration(unittest.TestCase):
model_id = "facebook/opt-350m"
def setUp(self):
config = AutoConfig.from_pretrained(self.model_id)
with init_empty_weights():
self.model = AutoModelForCausalLM.from_config(config)
self.nb_linear = 0
self.nb_layernorm = 0
for module in self.model.modules():
if isinstance(module, torch.nn.Linear):
self.nb_linear += 1
elif isinstance(module, torch.nn.LayerNorm):
self.nb_layernorm += 1
def test_weight_only_quantization_conversion(self):
"""
Simple test that checks if the quantized model has been converted properly when using weight only quantization
"""
# Try with weight only quantization
quantization_config = QuantoConfig(weights="int8", activations=None)
self.model, _ = replace_with_quanto_layers(self.model, quantization_config=quantization_config)
nb_qlinear = 0
for module in self.model.modules():
if isinstance(module, QLinear):
nb_qlinear += 1
self.assertEqual(self.nb_linear, nb_qlinear)
def test_weight_and_activation_quantization_conversion(self):
"""
Simple test that checks if the quantized model has been converted properly when using weight + activation quantization
"""
# Try with weight + activation quantization
quantization_config = QuantoConfig(weights="int8", activations="int8")
self.model, _ = replace_with_quanto_layers(self.model, quantization_config=quantization_config)
nb_qlinear = 0
nb_qlayernorm = 0
for module in self.model.modules():
if isinstance(module, QLinear):
nb_qlinear += 1
if isinstance(module, QLayerNorm):
nb_qlayernorm += 1
self.assertEqual(self.nb_linear, nb_qlinear)
self.assertEqual(self.nb_layernorm, nb_qlayernorm)
def test_conversion_with_modules_to_not_convert(self):
"""
Simple test that checks if the quantized model has been converted properly when specifying modules_to_not_convert argument
"""
# Try with weight + activation quantization
quantization_config = QuantoConfig(weights="int8", activations="int8")
self.model, _ = replace_with_quanto_layers(
self.model, quantization_config=quantization_config, modules_to_not_convert=["lm_head"]
)
nb_qlinear = 0
nb_qlayernorm = 0
for module in self.model.modules():
if isinstance(module, QLinear):
nb_qlinear += 1
if isinstance(module, QLayerNorm):
nb_qlayernorm += 1
self.assertEqual(self.nb_linear - 1, nb_qlinear)
@slow
@require_torch_gpu
@require_quanto
@require_accelerate
class QuantoQuantizationTest(unittest.TestCase):
"""
Test 8-bit weights only quantization
"""
model_name = "bigscience/bloom-560m"
weights = "int8"
activations = None
device_map = "cpu"
input_text = "Hello my name is"
EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer and I"
def setUp(self):
"""
Setup quantized model
"""
quantization_config = QuantoConfig(
weights=self.weights,
activations=self.activations,
)
self.quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device_map,
quantization_config=quantization_config,
torch_dtype=torch.float32,
)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.have_accelerate_hooks = (
getattr(self.quantized_model, "hf_device_map", False) and len(self.quantized_model.hf_device_map) > 1
)
def check_inference_correctness(self, model, device):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.
Given that we are operating on small numbers and the testing model is relatively small, we might not get
the same output across GPUs. So we'll generate a few tokens (5-10) and check their output.
"""
if not self.have_accelerate_hooks:
model.to(device)
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(device), max_new_tokens=10)
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_generate_quality_cpu(self):
"""
Simple test to check the quality of the model on cpu by comparing the generated tokens with the expected tokens
"""
self.check_inference_correctness(self.quantized_model, "cpu")
def test_generate_quality_cuda(self):
"""
Simple test to check the quality of the model on cuda by comparing the generated tokens with the expected tokens
"""
self.check_inference_correctness(self.quantized_model, "cuda")
def test_quantized_model_layers(self):
"""
Suite of simple tests to check if the layers are quantized and are working properly
"""
from quanto import QBitsTensor, QModuleMixin, QTensor
# Test the type of the quantized layer
self.assertTrue(isinstance(self.quantized_model.transformer.h[0].self_attention.query_key_value, QModuleMixin))
self.assertTrue(
isinstance(self.quantized_model.transformer.h[0].self_attention.query_key_value.weight, QTensor)
)
if self.weights == "int4":
self.assertTrue(
isinstance(self.quantized_model.transformer.h[0].self_attention.query_key_value.weight, QBitsTensor)
)
# check that the lm_head was indeed not quantized, just like bnb
self.assertTrue(
isinstance(self.quantized_model.lm_head, torch.nn.Linear)
and not isinstance(self.quantized_model.lm_head, QModuleMixin)
)
if self.device_map in ["cpu", "cuda"]:
self.assertEqual(
self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type,
self.device_map,
)
self.quantized_model.to(0)
self.assertEqual(
self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, "cuda"
)
def test_serialization_bin(self):
"""
Test the serialization, the loading and the inference of the quantized weights
"""
with tempfile.TemporaryDirectory() as tmpdirname:
with self.assertRaises(ValueError) as e:
self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
self.assertIn("The model is quantized with quanto and is not serializable", str(e.exception))
# TODO: replace by the following when it works
# quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
# tmpdirname, torch_dtype=torch.float32, device_map="cpu"
# )
# self.check_inference_correctness(quantized_model_from_saved, device="cuda")
def test_serialization_safetensors(self):
"""
Test the serialization, the loading and the inference of the quantized weights
"""
with tempfile.TemporaryDirectory() as tmpdirname:
with self.assertRaises(ValueError) as e:
self.quantized_model.save_pretrained(tmpdirname)
self.assertIn("The model is quantized with quanto and is not serializable", str(e.exception))
# quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
# tmpdirname, torch_dtype=torch.float32, device_map="cpu"
# )
# self.check_inference_correctness(quantized_model_from_saved, device="cuda")
def check_same_model(self, model1, model2):
d0 = dict(model1.named_parameters())
d1 = dict(model2.named_parameters())
self.assertTrue(d0.keys() == d1.keys())
for k in d0.keys():
self.assertTrue(d0[k].shape == d1[k].shape)
self.assertTrue(d0[k].device.type == d1[k].device.type)
self.assertTrue(d0[k].device == d1[k].device)
self.assertTrue(d0[k].dtype == d1[k].dtype)
self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))
def test_compare_with_quanto(self):
from quanto import freeze, qint4, qint8, quantize
w_mapping = {"int8": qint8, "int4": qint4}
model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device_map,
torch_dtype=torch.float32,
)
# we do not quantize the lm_head since we don't do that in transformers
quantize(model.transformer, weights=w_mapping[self.weights])
freeze(model.transformer)
self.check_same_model(model, self.quantized_model)
self.check_inference_correctness(model, device="cuda")
@unittest.skip
def test_load_from_quanto_saved(self):
from quanto import freeze, qint4, qint8, quantize
from transformers import QuantoConfig
w_mapping = {"int8": qint8, "int4": qint4}
model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device_map,
torch_dtype=torch.float32,
)
# we do not quantize the lm_head since we don't do that in transformers
quantize(model.transformer, weights=w_mapping[self.weights])
freeze(model.transformer)
with tempfile.TemporaryDirectory() as tmpdirname:
model.config.quantization_config = QuantoConfig(
weights=self.weights, activations=self.activations, modules_to_not_convert=["lm_head"]
)
model.save_pretrained(tmpdirname, safe_serialization=False)
quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
tmpdirname,
device_map=self.device_map,
torch_dtype=torch.float32,
)
self.check_same_model(model, quantized_model_from_saved)
self.check_inference_correctness(quantized_model_from_saved, device="cuda")
class QuantoQuantizationOffloadTest(QuantoQuantizationTest):
device_map = {
"transformer.word_embeddings": 0,
"transformer.word_embeddings_layernorm": 0,
"transformer.ln_f": 0,
"transformer.h.0": 0,
"transformer.h.1": 0,
"transformer.h.2": 0,
"transformer.h.3": 0,
"transformer.h.4": 0,
"transformer.h.5": 0,
"transformer.h.6": 0,
"transformer.h.7": 0,
"transformer.h.8": 0,
"transformer.h.9": 0,
"transformer.h.10": 0,
"transformer.h.11": 0,
"transformer.h.12": 0,
"transformer.h.13": 0,
"transformer.h.14": 0,
"transformer.h.15": 0,
"transformer.h.16": 0,
"transformer.h.17": 0,
"transformer.h.18": 0,
"transformer.h.19": 0,
"transformer.h.20": 0,
"transformer.h.21": 0,
"transformer.h.22": "cpu",
"transformer.h.23": "disk",
"lm_head": 0,
}
# the execution device is a gpu
def test_generate_quality_cpu(self):
pass
# we can't save offloaded values
def test_serialization_bin(self):
pass
def test_serialization_safetensors(self):
pass
def test_compare_with_quanto(self):
pass
def test_load_from_quanto_saved(self):
pass
def test_check_offload_quantized(self):
"""
We check that we have unquantized values on the cpu and on the disk
"""
import quanto
cpu_weights = self.quantized_model.transformer.h[22].self_attention.query_key_value._hf_hook.weights_map[
"weight"
]
disk_weights = self.quantized_model.transformer.h[23].self_attention.query_key_value._hf_hook.weights_map[
"weight"
]
self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, quanto.QTensor))
self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QTensor))
if self.weights == "int4":
self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, quanto.QBitsTensor))
self.assertTrue(
isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QBitsTensor)
)
@unittest.skip("Skipping test class because serialization is not supported yet")
class QuantoQuantizationSerializationTest(QuantoQuantizationTest):
"""
Perform the same tests as in QuantoQuantizationTest but with a serialized model.
"""
def setUp(self):
"""
Setup quantized model
"""
quantization_config = QuantoConfig(
weights=self.weights,
activations=self.activations,
)
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device_map,
quantization_config=quantization_config,
torch_dtype=torch.float32,
)
with tempfile.TemporaryDirectory() as tmpdirname:
quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
self.quantized_model = AutoModelForCausalLM.from_pretrained(
tmpdirname, torch_dtype=torch.float32, device_map=self.device_map
)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.have_accelerate_hooks = (
getattr(self.quantized_model, "hf_device_map", False) and len(self.quantized_model.hf_device_map) > 1
)
@unittest.skip("Skipping test class because serialization is not supported yet")
class QuantoQuantizationSerializationCudaTest(QuantoQuantizationTest):
"""
Perform the same tests as in QuantoQuantizationTest but with the model on cuda
"""
device_map = "cuda:0"
class QuantoQuantizationQBitsTensorTest(QuantoQuantizationTest):
EXPECTED_OUTPUTS = "Hello my name is John, I am a young man from the Philippines"
weights = "int4"
class QuantoQuantizationQBitsTensorOffloadTest(QuantoQuantizationOffloadTest):
EXPECTED_OUTPUTS = "Hello my name is John, I am a young man from the Philippines"
weights = "int4"
@unittest.skip("Skipping test class because serialization is not supported yet")
class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializationTest):
EXPECTED_OUTPUTS = "Hello my name is John, I am a young man from the Philippines"
weights = "int4"
@require_torch_gpu
class QuantoQuantizationActivationTest(unittest.TestCase):
def test_quantize_activation(self):
quantization_config = QuantoConfig(
weights="int8",
activations="int8",
)
with self.assertRaises(ValueError) as e:
AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", quantization_config=quantization_config)
self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception))