[Quantization] Quanto quantizer (#29023)
* start integration * fix * add and debug tests * update tests * make pytorch serialization works * compatible with device_map and offload * fix tests * make style * add ref * guard against safetensors * add float8 and style * fix is_serializable * Fix shard_checkpoint compatibility with quanto * more tests * docs * adjust memory * better * style * pass tests * Update src/transformers/modeling_utils.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * add is_safe_serialization instead * Update src/transformers/quantizers/quantizer_quanto.py Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> * add QbitsTensor tests * fix tests * simplify activation list * Update docs/source/en/quantization.md Co-authored-by: David Corvoysier <david.corvoysier@gmail.com> * better comment * Update tests/quantization/quanto_integration/test_quanto.py Co-authored-by: David Corvoysier <david.corvoysier@gmail.com> * Update tests/quantization/quanto_integration/test_quanto.py Co-authored-by: David Corvoysier <david.corvoysier@gmail.com> * find and fix edge case * Update docs/source/en/quantization.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * pass weights_only_kwarg instead * fix shard_checkpoint loading * simplify update_missing_keys * Update tests/quantization/quanto_integration/test_quanto.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * recursion to get all tensors * block serialization * skip serialization tests * fix * change by cuda:0 for now * fix regression * update device_map * fix doc * add noteboon * update torch_dtype * update doc * typo * typo * remove comm --------- Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Co-authored-by: David Corvoysier <david.corvoysier@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: Younes Belkada <younesbelkada@gmail.com>
This commit is contained in: parent f02aea2737, commit 28de2f4de3
@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl

# Add quanto for quantization testing
RUN python3 -m pip install --no-cache-dir quanto

# When installing in editable mode, `transformers` is not recognized as a package.
# This line must be added in order for python to be aware of transformers.
RUN cd transformers && python3 setup.py develop
@ -26,6 +26,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
</Tip>

## QuantoConfig

[[autodoc]] QuantoConfig

## AqlmConfig

[[autodoc]] AqlmConfig
|
@ -26,6 +26,59 @@ Interested in adding a new quantization method to Transformers? Read the [HfQuan
</Tip>

## Quanto

<Tip>

Try Quanto + transformers with this [notebook](https://colab.research.google.com/drive/16CXfVmtdQvciSh9BopZUDYcmXCDpvgrT?usp=sharing)!

</Tip>

The [🤗 Quanto](https://github.com/huggingface/quanto) library is a versatile PyTorch quantization toolkit. It uses linear quantization and provides several unique features:

- weight quantization (`float8`, `int8`, `int4`, `int2`)
- activation quantization (`float8`, `int8`)
- modality agnostic (e.g. CV, LLM)
- device agnostic (e.g. CUDA, MPS, CPU)
- compatibility with `torch.compile`
- easy to add custom kernels for a specific device
- supports quantization-aware training
<!-- Add link to the blogpost -->

Before you begin, make sure the following libraries are installed:

```bash
pip install quanto
pip install git+https://github.com/huggingface/accelerate.git
pip install git+https://github.com/huggingface/transformers.git
```

Now you can quantize a model by passing a [`QuantoConfig`] object to the [`~PreTrainedModel.from_pretrained`] method. This works for any model in any modality, as long as it contains `torch.nn.Linear` layers.

The integration with transformers only supports weight quantization. For more complex use cases such as activation quantization, calibration, and quantization-aware training, you should use the [quanto](https://github.com/huggingface/quanto) library instead.

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantization_config = QuantoConfig(weights="int8")
quantized_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0", quantization_config=quantization_config)
```

Note that serialization is not supported yet with transformers, but it is coming soon! If you want to save the model, you can use the quanto library instead.

Quanto uses a linear quantization algorithm. Even though this is a basic quantization technique, it gives very good results! Have a look at the following benchmark (perplexity of llama-2-7b). You can find more benchmarks [here](https://github.com/huggingface/quanto/tree/main/bench/generation)

<div class="flex gap-4">
  <div>
    <img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/NousResearch-Llama-2-7b-hf_Perplexity.png" alt="llama-2-7b-quanto-perplexity" />
  </div>
</div>

The library is versatile enough to be compatible with most PTQ optimization algorithms. The plan for the future is to integrate the most popular algorithms as seamlessly as possible (AWQ, SmoothQuant).

## AQLM
@ -1100,7 +1100,7 @@ _import_structure = {
    "is_vision_available",
    "logging",
],
"utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
"utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig", "QuantoConfig"],
}

# sentencepiece-backed objects
@ -5921,7 +5921,7 @@ if TYPE_CHECKING:
    )

    # bitsandbytes config
    from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig
    from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig, QuantoConfig

    try:
        if not is_sentencepiece_available():
@ -82,6 +82,7 @@ _import_structure = {
        "run_hp_search_wandb",
    ],
    "peft": ["PeftAdapterMixin"],
    "quanto": ["replace_with_quanto_layers"],
}

if TYPE_CHECKING:
@ -150,6 +151,7 @@ if TYPE_CHECKING:
        run_hp_search_wandb,
    )
    from .peft import PeftAdapterMixin
    from .quanto import replace_with_quanto_layers
else:
    import sys
@ -0,0 +1,94 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_torch_available


if is_torch_available():
    import torch


def replace_with_quanto_layers(
    model,
    quantization_config=None,
    modules_to_not_convert=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with Quanto quantized layers.
    Returns the converted model and a boolean that indicates if the conversion has been successful or not.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`QuantoConfig`, defaults to `None`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
            converted.
        current_key_name (`list`, *optional*, defaults to `None`):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*, defaults to `False`):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """
    from accelerate import init_empty_weights
    from quanto import QLayerNorm, QLinear, qfloat8, qint2, qint4, qint8

    w_mapping = {"float8": qfloat8, "int8": qint8, "int4": qint4, "int2": qint2}
    a_mapping = {None: None, "float8": qfloat8, "int8": qint8}

    if modules_to_not_convert is None:
        modules_to_not_convert = []

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
            with init_empty_weights():
                if isinstance(module, torch.nn.Linear):
                    model._modules[name] = QLinear(
                        in_features=module.in_features,
                        out_features=module.out_features,
                        bias=module.bias is not None,
                        dtype=module.weight.dtype,
                        weights=w_mapping[quantization_config.weights],
                        activations=a_mapping[quantization_config.activations],
                    )
                    model._modules[name].requires_grad_(False)
                    has_been_replaced = True
                elif isinstance(module, torch.nn.LayerNorm):
                    if quantization_config.activations is not None:
                        model._modules[name] = QLayerNorm(
                            module.normalized_shape,
                            module.eps,
                            module.elementwise_affine,
                            module.bias is not None,
                            activations=a_mapping[quantization_config.activations],
                        )
                        has_been_replaced = True
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_quanto_layers(
                module,
                quantization_config=quantization_config,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
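The conversion above skips a module when any excluded key is a substring of its dotted path. This dependency-free sketch illustrates that traversal pattern; `DummyModule` and `collect_convertible` are illustrative stand-ins, not part of transformers or quanto:

```python
# Sketch of the dotted-path exclusion check used by replace_with_quanto_layers.
# DummyModule stands in for torch.nn.Module; only named_children() is needed.
class DummyModule:
    def __init__(self, **children):
        self._children = children

    def named_children(self):
        return self._children.items()


def collect_convertible(model, modules_to_not_convert, current_key_name=None, found=None):
    """Recursively gather dotted names of leaf modules that would be converted."""
    if found is None:
        found = []
    if current_key_name is None:
        current_key_name = []
    for name, module in model.named_children():
        current_key_name.append(name)
        path = ".".join(current_key_name)
        # Same substring test as the integration: any excluded key appearing
        # anywhere in the dotted path skips conversion of this module.
        if not any(key in path for key in modules_to_not_convert):
            if not module._children:  # leaf: would be replaced by a QLinear
                found.append(path)
        # Like the real code, recursion into children happens regardless.
        collect_convertible(module, modules_to_not_convert, current_key_name, found)
        current_key_name.pop(-1)  # remove the last key for recursion
    return found


model = DummyModule(
    model=DummyModule(decoder=DummyModule(q_proj=DummyModule(), k_proj=DummyModule())),
    lm_head=DummyModule(),
)
print(collect_convertible(model, ["lm_head"]))
```

Note that excluding `lm_head` removes it from the convertible set while the rest of the tree is still visited, matching the behavior of the integration code.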
@ -802,7 +802,11 @@ def _load_state_dict_into_meta_model(
elif (
    not is_quantized
    or (not hf_quantizer.requires_parameters_quantization)
    or (not hf_quantizer.check_quantized_param(model, param, param_name, state_dict))
    or (
        not hf_quantizer.check_quantized_param(
            model, param, param_name, state_dict, param_device=param_device, device_map=device_map
        )
    )
):
    # For backward compatibility with older versions of `accelerate` and for non-quantized params
    set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
@ -3728,6 +3732,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
for pat in cls._keys_to_ignore_on_load_unexpected:
    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

if hf_quantizer is not None:
    missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)

# Retrieve weights on meta device and put them back on CPU.
# This is not ideal in terms of memory, but if we don't do it now, we can't initialize them in the next step.
if low_cpu_mem_usage:
@ -22,12 +22,14 @@ from ..utils.quantization_config import (
    GPTQConfig,
    QuantizationConfigMixin,
    QuantizationMethod,
    QuantoConfig,
)
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer


AUTO_QUANTIZER_MAPPING = {
@ -36,6 +38,7 @@ AUTO_QUANTIZER_MAPPING = {
    "bitsandbytes_8bit": Bnb8BitHfQuantizer,
    "gptq": GptqHfQuantizer,
    "aqlm": AqlmHfQuantizer,
    "quanto": QuantoHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@ -44,6 +47,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "bitsandbytes_8bit": BitsAndBytesConfig,
    "gptq": GPTQConfig,
    "aqlm": AqlmConfig,
    "quanto": QuantoConfig,
}
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from ..utils import is_torch_available
from ..utils.quantization_config import QuantizationConfigMixin
@ -99,6 +99,16 @@ class HfQuantizer(ABC):
        """
        return torch_dtype

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        """
        Override this method if you want to adjust the `missing_keys`.

        Args:
            missing_keys (`List[str]`, *optional*):
                The list of missing keys in the checkpoint compared to the state dict of the model
        """
        return missing_keys

    def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
        """
        Returns dtypes for modules that are not quantized - used for the computation of the device_map in case
@ -111,6 +121,7 @@ class HfQuantizer(ABC):
            torch_dtype (`torch.dtype`):
                The dtype passed in `from_pretrained` method.
        """
        return {
            name: torch_dtype
            for name, _ in model.named_parameters()
@ -122,7 +133,12 @@ class HfQuantizer(ABC):
        return max_memory

    def check_quantized_param(
        self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        """
        Checks if a loaded state_dict component is part of a quantized param + some validation; only defined if
@ -116,7 +116,12 @@ class Bnb4BitHfQuantizer(HfQuantizer):
        )

    def check_quantized_param(
        self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        import bitsandbytes as bnb
@ -134,7 +134,12 @@ class Bnb8BitHfQuantizer(HfQuantizer):
        return torch.int8

    def check_quantized_param(
        self, model: "PreTrainedModel", param_value: "torch.Tensor", param_name: str, state_dict: Dict[str, Any]
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        import bitsandbytes as bnb
@ -0,0 +1,198 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from packaging import version

from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..utils import is_accelerate_available, is_quanto_available, is_torch_available, logging
from ..utils.quantization_config import QuantoConfig


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class QuantoHfQuantizer(HfQuantizer):
    """
    Quantizer for the quanto library
    """

    required_packages = ["quanto", "accelerate"]
    requires_parameters_quantization = True
    requires_calibration = False

    def __init__(self, quantization_config: QuantoConfig, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.post_init()

    def post_init(self):
        r"""
        Safety checker
        """
        if self.quantization_config.activations is not None and not self.pre_quantized:
            raise ValueError(
                "We don't support quantizing the activations with the transformers library. "
                "Use the quanto library for more complex use cases such as activation quantization, calibration and quantization aware training."
            )

    def validate_environment(self, *args, **kwargs):
        if not is_quanto_available():
            raise ImportError("Loading a quanto quantized model requires the quanto library (`pip install quanto`)")
        if not is_accelerate_available():
            raise ImportError("Loading a quanto quantized model requires the accelerate library (`pip install accelerate`)")

    def update_device_map(self, device_map):
        if device_map is None:
            device_map = {"": "cpu"}
            logger.info(
                "The device_map was not initialized. "
                "Setting device_map to {'': 'cpu'}. "
                "If you want to use the model for inference, please set device_map='auto'"
            )
        return device_map

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
            torch_dtype = torch.float32
        return torch_dtype

    def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]:
        import quanto

        not_missing_keys = []
        for name, module in model.named_modules():
            if isinstance(module, quanto.QModuleMixin):
                for missing in missing_keys:
                    if (
                        (name in missing or name in f"{prefix}.{missing}")
                        and not missing.endswith(".weight")
                        and not missing.endswith(".bias")
                    ):
                        not_missing_keys.append(missing)
        return [k for k in missing_keys if k not in not_missing_keys]
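The `update_missing_keys` filtering can be exercised without quanto or torch. In this minimal sketch, `filter_missing_keys` reproduces the same key logic over plain strings; the module and tensor names (`._data`, `input_scale`) are illustrative examples of quantizer bookkeeping tensors, not guaranteed quanto names:

```python
# Sketch of update_missing_keys: keys that belong to a quantized module are
# dropped from missing_keys unless they end in .weight or .bias (those still
# need to be materialized and quantized at load time).
def filter_missing_keys(quantized_module_names, missing_keys, prefix):
    not_missing = []
    for name in quantized_module_names:
        for missing in missing_keys:
            if (
                (name in missing or name in f"{prefix}.{missing}")
                and not missing.endswith(".weight")
                and not missing.endswith(".bias")
            ):
                not_missing.append(missing)
    return [k for k in missing_keys if k not in not_missing]


missing = [
    "transformer.h.0.attn.weight",        # real weight: keep as missing
    "transformer.h.0.attn.weight._data",  # illustrative quantizer bookkeeping tensor
    "transformer.h.0.attn.input_scale",   # illustrative scale buffer
]
print(filter_missing_keys(["transformer.h.0.attn"], missing, "model"))
```

Only the `.weight` key survives the filter; the bookkeeping tensors are created by the quantizer itself, so reporting them as missing would be misleading.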
    def check_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        """
        Check if a parameter needs to be quantized.
        """
        import quanto

        device_map = kwargs.get("device_map", None)
        param_device = kwargs.get("param_device", None)
        # we don't quantize the model if the module is going to be offloaded to the cpu
        if device_map is not None and param_device is not None:
            device_map_values = set(device_map.values())
            if param_device == "cpu" and len(device_map_values) > 1:
                if not (device_map_values == {"cpu"} or device_map_values == {"cpu", "disk"}):
                    return False

        module, tensor_name = get_module_from_name(model, param_name)
        # We only quantize the weights; the bias is not quantized.
        if isinstance(module, quanto.QModuleMixin) and "weight" in tensor_name:
            # if the weights are already quantized, there is no need to recreate them with `create_quantized_param`
            return not module.frozen
        else:
            return False

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create the quantized parameter by calling .freeze() after setting it on the module.
        """
        from accelerate.utils import set_module_tensor_to_device

        set_module_tensor_to_device(model, param_name, target_device, param_value)
        module, _ = get_module_from_name(model, param_name)
        module.freeze()
        module.weight.requires_grad = False

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.27.0"):
            from accelerate.utils import CustomDtype

            mapping = {
                "int8": torch.int8,
                "float8": CustomDtype.FP8,
                "int4": CustomDtype.INT4,
                "int2": CustomDtype.INT2,
            }
            target_dtype = mapping[self.quantization_config.weights]
            return target_dtype
        else:
            raise ValueError(
                "You are using `device_map='auto'` on a quanto quantized model. To automatically compute"
                " the appropriate device map, you should upgrade your `accelerate` library with "
                "`pip install --upgrade accelerate` or install it from source."
            )

    def _process_model_before_weight_loading(
        self, model: "PreTrainedModel", keep_in_fp32_modules: List[str] = [], **kwargs
    ):
        from ..integrations import get_keys_to_not_convert, replace_with_quanto_layers

        # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
        if self.quantization_config.modules_to_not_convert is None:
            self.modules_to_not_convert = get_keys_to_not_convert(model)
        else:
            self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]

        self.modules_to_not_convert.extend(keep_in_fp32_modules)

        model, _ = replace_with_quanto_layers(
            model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model):
        return model

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    @property
    def is_serializable(self):
        return False
@ -89,6 +89,7 @@ from .utils import (
    is_pytesseract_available,
    is_pytest_available,
    is_pytorch_quantization_available,
    is_quanto_available,
    is_rjieba_available,
    is_sacremoses_available,
    is_safetensors_available,
@ -1043,6 +1044,13 @@ def require_auto_awq(test_case):
    return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)


def require_quanto(test_case):
    """
    Decorator marking a test that requires quanto
    """
    return unittest.skipUnless(is_quanto_available(), "test requires quanto")(test_case)


def require_phonemizer(test_case):
    """
    Decorator marking a test that requires phonemizer
@ -152,6 +152,7 @@ from .import_utils import (
    is_pytesseract_available,
    is_pytest_available,
    is_pytorch_quantization_available,
    is_quanto_available,
    is_rjieba_available,
    is_sacremoses_available,
    is_safetensors_available,
@ -131,6 +131,7 @@ _optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quanto_available = _is_package_available("quanto")
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
@ -788,6 +789,10 @@ def is_auto_awq_available():
    return _auto_awq_available


def is_quanto_available():
    return _quanto_available


def is_auto_gptq_available():
    return _auto_gptq_available
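The availability flags above are computed once at import time, and the `awq` special case checks for a module spec instead of package metadata. A small self-contained sketch of that `find_spec` pattern (the helper name `is_package_available` is illustrative):

```python
import importlib.util


def is_package_available(name: str) -> bool:
    # Mirrors the spirit of the transformers availability checks: a package
    # counts as "available" if importlib can locate it, without importing it.
    # This also covers packages whose metadata `importlib.metadata.version`
    # cannot resolve, which is why the diff uses it for `awq`.
    return importlib.util.find_spec(name) is not None


print(is_package_available("json"))                      # stdlib, always present
print(is_package_available("definitely_not_a_real_pkg"))
```

Checking the spec rather than importing avoids paying the import cost (and any import-time side effects) for optional dependencies.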
@ -39,6 +39,7 @@ class QuantizationMethod(str, Enum):
    GPTQ = "gptq"
    AWQ = "awq"
    AQLM = "aqlm"
    QUANTO = "quanto"


class AWQLinearVersion(str, Enum):
@ -836,3 +837,44 @@ class AqlmConfig(QuantizationConfigMixin):
        if self.linear_weights_not_to_quantize is None:
            self.linear_weights_not_to_quantize = []


@dataclass
class QuantoConfig(QuantizationConfigMixin):
    """
    This is a wrapper class over all the possible attributes and features that you can play with for a model that has
    been loaded using `quanto`.

    Args:
        weights (`str`, *optional*, defaults to `"int8"`):
            The target dtype for the weights after quantization. Supported values are ("float8", "int8", "int4", "int2")
        activations (`str`, *optional*):
            The target dtype for the activations after quantization. Supported values are (None, "int8", "float8")
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require some
            modules to be left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
    """

    def __init__(
        self,
        weights="int8",
        activations=None,
        modules_to_not_convert: Optional[List] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.QUANTO
        self.weights = weights
        self.activations = activations
        self.modules_to_not_convert = modules_to_not_convert
        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct
        """
        accepted_weights = ["float8", "int8", "int4", "int2"]
        accepted_activations = [None, "int8", "float8"]
        if self.weights not in accepted_weights:
            raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")
        if self.activations not in accepted_activations:
            raise ValueError(f"Only support activations in {accepted_activations} but found {self.activations}")
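The `post_init` guard above is a plain allow-list check that rejects bad dtype strings at construction time. A dependency-free sketch of the same pattern (`ToyQuantConfig` is illustrative, not a transformers class):

```python
# Minimal allow-list validation in the style of QuantoConfig.post_init.
class ToyQuantConfig:
    ACCEPTED_WEIGHTS = ("float8", "int8", "int4", "int2")
    ACCEPTED_ACTIVATIONS = (None, "int8", "float8")

    def __init__(self, weights="int8", activations=None):
        self.weights = weights
        self.activations = activations
        self.post_init()

    def post_init(self):
        # Reject unsupported dtypes eagerly, at construction time, so a bad
        # config fails before any model loading happens.
        if self.weights not in self.ACCEPTED_WEIGHTS:
            raise ValueError(f"Only support weights in {self.ACCEPTED_WEIGHTS} but found {self.weights}")
        if self.activations not in self.ACCEPTED_ACTIVATIONS:
            raise ValueError(f"Only support activations in {self.ACCEPTED_ACTIVATIONS} but found {self.activations}")


cfg = ToyQuantConfig(weights="int4")
print(cfg.weights)
try:
    ToyQuantConfig(weights="int1")
except ValueError as err:
    print("rejected:", err)
```

Failing fast in the constructor means downstream code (the quantizer, `from_pretrained`) never has to re-validate the dtype strings.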
@ -0,0 +1,431 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig
|
||||
from transformers.testing_utils import require_accelerate, require_quanto, require_torch_gpu, slow
|
||||
from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
if is_quanto_available():
|
||||
from quanto import QLayerNorm, QLinear
|
||||
|
||||
from transformers.integrations.quanto import replace_with_quanto_layers
|
||||
|
||||
|
||||
class QuantoConfigTest(unittest.TestCase):
|
||||
def test_attributes(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_quanto
|
||||
@require_accelerate
|
||||
class QuantoTestIntegration(unittest.TestCase):
|
||||
model_id = "facebook/opt-350m"
|
||||
|
||||
def setUp(self):
|
||||
config = AutoConfig.from_pretrained(self.model_id)
|
||||
with init_empty_weights():
|
||||
self.model = AutoModelForCausalLM.from_config(config)
|
||||
self.nb_linear = 0
|
||||
self.nb_layernorm = 0
|
||||
for module in self.model.modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
self.nb_linear += 1
|
||||
elif isinstance(module, torch.nn.LayerNorm):
|
||||
self.nb_layernorm += 1
|
||||
|
||||
def test_weight_only_quantization_conversion(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model has been converted properly when using weight only quantization
|
||||
"""
|
||||
|
||||
# Try with weight only quantization
|
||||
quantization_config = QuantoConfig(weights="int8", activations=None)
|
||||
self.model, _ = replace_with_quanto_layers(self.model, quantization_config=quantization_config)
|
||||
|
||||
nb_qlinear = 0
|
||||
for module in self.model.modules():
|
||||
if isinstance(module, QLinear):
|
||||
nb_qlinear += 1
|
||||
|
||||
self.assertEqual(self.nb_linear, nb_qlinear)
|
||||
|
||||
def test_weight_and_activation_quantization_conversion(self):
|
||||
"""
|
||||
Simple test that checks if the quantized model has been converted properly when using weight + activation quantization
|
||||
"""
|
||||
|
||||
# Try with weight + activation quantization
|
||||
quantization_config = QuantoConfig(weights="int8", activations="int8")
|
||||
self.model, _ = replace_with_quanto_layers(self.model, quantization_config=quantization_config)
|
||||
|
||||
nb_qlinear = 0
|
||||
nb_qlayernorm = 0
|
||||
for module in self.model.modules():
|
||||
if isinstance(module, QLinear):
|
||||
nb_qlinear += 1
|
||||
if isinstance(module, QLayerNorm):
|
||||
nb_qlayernorm += 1
|
||||
|
||||
self.assertEqual(self.nb_linear, nb_qlinear)
|
||||
self.assertEqual(self.nb_layernorm, nb_qlayernorm)
|
||||
|

    def test_conversion_with_modules_to_not_convert(self):
        """
        Simple test that checks if the quantized model has been converted properly when specifying the modules_to_not_convert argument
        """

        # Try with weight + activation quantization
        quantization_config = QuantoConfig(weights="int8", activations="int8")
        self.model, _ = replace_with_quanto_layers(
            self.model, quantization_config=quantization_config, modules_to_not_convert=["lm_head"]
        )

        nb_qlinear = 0
        nb_qlayernorm = 0
        for module in self.model.modules():
            if isinstance(module, QLinear):
                nb_qlinear += 1
            if isinstance(module, QLayerNorm):
                nb_qlayernorm += 1

        self.assertEqual(self.nb_linear - 1, nb_qlinear)
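The test above expects exactly one linear layer, the `lm_head`, to be left unconverted. The name filtering that `modules_to_not_convert` implies can be sketched with a small hypothetical helper (an illustration only, not the actual transformers implementation):

```python
# Hypothetical sketch of the name filtering behind modules_to_not_convert:
# a module is converted only when none of the excluded names occurs in its
# dotted path, mirroring a substring check over module names.
def should_convert(module_path: str, modules_to_not_convert: list) -> bool:
    return not any(excluded in module_path for excluded in modules_to_not_convert)


paths = ["model.decoder.layers.0.self_attn.q_proj", "lm_head"]
converted = [p for p in paths if should_convert(p, ["lm_head"])]
print(converted)  # ['model.decoder.layers.0.self_attn.q_proj']
```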


@slow
@require_torch_gpu
@require_quanto
@require_accelerate
class QuantoQuantizationTest(unittest.TestCase):
    """
    Test 8-bit weights-only quantization
    """

    model_name = "bigscience/bloom-560m"

    weights = "int8"
    activations = None
    device_map = "cpu"

    input_text = "Hello my name is"
    EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer and I"

    def setUp(self):
        """
        Setup quantized model
        """
        quantization_config = QuantoConfig(
            weights=self.weights,
            activations=self.activations,
        )

        self.quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            quantization_config=quantization_config,
            torch_dtype=torch.float32,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        self.have_accelerate_hooks = (
            getattr(self.quantized_model, "hf_device_map", False) and len(self.quantized_model.hf_device_map) > 1
        )

    def check_inference_correctness(self, model, device):
        r"""
        Test the generation quality of the quantized model by checking that we match the expected output.
        Given that we are operating on small numbers and the testing model is relatively small, we might not get
        the same output across GPUs, so we only generate a few tokens (5-10) and check their output.
        """
        if not self.have_accelerate_hooks:
            model.to(device)
        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
        output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(device), max_new_tokens=10)
        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_generate_quality_cpu(self):
        """
        Simple test to check the quality of the model on cpu by comparing the generated tokens with the expected tokens
        """
        self.check_inference_correctness(self.quantized_model, "cpu")

    def test_generate_quality_cuda(self):
        """
        Simple test to check the quality of the model on cuda by comparing the generated tokens with the expected tokens
        """
        self.check_inference_correctness(self.quantized_model, "cuda")

    def test_quantized_model_layers(self):
        """
        Suite of simple tests to check that the layers are quantized and working properly
        """
        from quanto import QBitsTensor, QModuleMixin, QTensor

        # Test the type of the quantized layer
        self.assertTrue(isinstance(self.quantized_model.transformer.h[0].self_attention.query_key_value, QModuleMixin))
        self.assertTrue(
            isinstance(self.quantized_model.transformer.h[0].self_attention.query_key_value.weight, QTensor)
        )
        if self.weights == "int4":
            self.assertTrue(
                isinstance(self.quantized_model.transformer.h[0].self_attention.query_key_value.weight, QBitsTensor)
            )

        # check that the lm_head was indeed not quantized, just like bnb
        self.assertTrue(
            isinstance(self.quantized_model.lm_head, torch.nn.Linear)
            and not isinstance(self.quantized_model.lm_head, QModuleMixin)
        )
        if self.device_map in ["cpu", "cuda"]:
            self.assertEqual(
                self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type,
                self.device_map,
            )
            self.quantized_model.to(0)
            self.assertEqual(
                self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, "cuda"
            )

    def test_serialization_bin(self):
        """
        Test that saving the quantized weights to a pytorch checkpoint raises, since quanto serialization is not supported yet
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            with self.assertRaises(ValueError) as e:
                self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
            self.assertIn("The model is quantized with quanto and is not serializable", str(e.exception))
            # TODO: replace by the following when it works
            # quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
            #     tmpdirname, torch_dtype=torch.float32, device_map="cpu"
            # )
            # self.check_inference_correctness(quantized_model_from_saved, device="cuda")

    def test_serialization_safetensors(self):
        """
        Test that saving the quantized weights to a safetensors checkpoint raises, since quanto serialization is not supported yet
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            with self.assertRaises(ValueError) as e:
                self.quantized_model.save_pretrained(tmpdirname)
            self.assertIn("The model is quantized with quanto and is not serializable", str(e.exception))
            # TODO: replace by the following when it works
            # quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
            #     tmpdirname, torch_dtype=torch.float32, device_map="cpu"
            # )
            # self.check_inference_correctness(quantized_model_from_saved, device="cuda")

    def check_same_model(self, model1, model2):
        d0 = dict(model1.named_parameters())
        d1 = dict(model2.named_parameters())
        self.assertTrue(d0.keys() == d1.keys())
        for k in d0.keys():
            self.assertTrue(d0[k].shape == d1[k].shape)
            self.assertTrue(d0[k].device.type == d1[k].device.type)
            self.assertTrue(d0[k].device == d1[k].device)
            self.assertTrue(d0[k].dtype == d1[k].dtype)
            self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))

    def test_compare_with_quanto(self):
        from quanto import freeze, qint4, qint8, quantize

        w_mapping = {"int8": qint8, "int4": qint4}
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            torch_dtype=torch.float32,
        )
        # we do not quantize the lm_head since we don't do that in transformers
        quantize(model.transformer, weights=w_mapping[self.weights])
        freeze(model.transformer)
        self.check_same_model(model, self.quantized_model)
        self.check_inference_correctness(model, device="cuda")
    @unittest.skip
    def test_load_from_quanto_saved(self):
        from quanto import freeze, qint4, qint8, quantize

        from transformers import QuantoConfig

        w_mapping = {"int8": qint8, "int4": qint4}
        model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            torch_dtype=torch.float32,
        )
        # we do not quantize the lm_head since we don't do that in transformers
        quantize(model.transformer, weights=w_mapping[self.weights])
        freeze(model.transformer)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.config.quantization_config = QuantoConfig(
                weights=self.weights, activations=self.activations, modules_to_not_convert=["lm_head"]
            )
            model.save_pretrained(tmpdirname, safe_serialization=False)
            quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
                tmpdirname,
                device_map=self.device_map,
                torch_dtype=torch.float32,
            )
            self.check_same_model(model, quantized_model_from_saved)
            self.check_inference_correctness(quantized_model_from_saved, device="cuda")
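The weights-only int8 mode exercised throughout this class can be illustrated with a minimal symmetric per-tensor scheme. This is a conceptual sketch only, not quanto's actual algorithm or kernels:

```python
# Conceptual sketch (not quanto's implementation): symmetric per-tensor int8
# quantization. Each weight is scaled into [-127, 127], rounded to an integer,
# and stored alongside a single float scale used for dequantization.
def quantize_int8(weights):
    scale = max(abs(w) for w in weights) / 127 or 1.0  # fall back to 1.0 for all-zero weights
    q = [max(-128, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q, scale):
    return [v * scale for v in q]


q, scale = quantize_int8([0.5, -1.27, 0.02])
print(dequantize_int8(q, scale))  # values within one scale step of the originals
```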


class QuantoQuantizationOffloadTest(QuantoQuantizationTest):
    device_map = {
        "transformer.word_embeddings": 0,
        "transformer.word_embeddings_layernorm": 0,
        "transformer.ln_f": 0,
        "transformer.h.0": 0,
        "transformer.h.1": 0,
        "transformer.h.2": 0,
        "transformer.h.3": 0,
        "transformer.h.4": 0,
        "transformer.h.5": 0,
        "transformer.h.6": 0,
        "transformer.h.7": 0,
        "transformer.h.8": 0,
        "transformer.h.9": 0,
        "transformer.h.10": 0,
        "transformer.h.11": 0,
        "transformer.h.12": 0,
        "transformer.h.13": 0,
        "transformer.h.14": 0,
        "transformer.h.15": 0,
        "transformer.h.16": 0,
        "transformer.h.17": 0,
        "transformer.h.18": 0,
        "transformer.h.19": 0,
        "transformer.h.20": 0,
        "transformer.h.21": 0,
        "transformer.h.22": "cpu",
        "transformer.h.23": "disk",
        "lm_head": 0,
    }

    # the execution device is a gpu
    def test_generate_quality_cpu(self):
        pass

    # we can't save offloaded values
    def test_serialization_bin(self):
        pass

    def test_serialization_safetensors(self):
        pass

    def test_compare_with_quanto(self):
        pass

    def test_load_from_quanto_saved(self):
        pass

    def test_check_offload_quantized(self):
        """
        We check that the weights offloaded to the cpu and to the disk are not quantized
        """
        import quanto

        cpu_weights = self.quantized_model.transformer.h[22].self_attention.query_key_value._hf_hook.weights_map[
            "weight"
        ]
        disk_weights = self.quantized_model.transformer.h[23].self_attention.query_key_value._hf_hook.weights_map[
            "weight"
        ]
        self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, quanto.QTensor))
        self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QTensor))
        if self.weights == "int4":
            self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, quanto.QBitsTensor))
            self.assertTrue(
                isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QBitsTensor)
            )
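The offload test relies on accelerate dispatching layers 22 and 23 to the cpu and disk according to the device_map above. The placement lookup can be sketched as a longest-prefix match over dotted module paths (a hypothetical helper, not accelerate's actual API):

```python
# Hypothetical sketch of resolving a parameter's placement from a device_map:
# pick the longest dotted prefix of the parameter path present in the map,
# defaulting to "cpu" when nothing matches.
def resolve_device(param_path, device_map):
    best_prefix, best_device = "", "cpu"
    for prefix, device in device_map.items():
        if (param_path == prefix or param_path.startswith(prefix + ".")) and len(prefix) > len(best_prefix):
            best_prefix, best_device = prefix, device
    return best_device


device_map = {"transformer.h.22": "cpu", "transformer.h.23": "disk", "lm_head": 0}
print(resolve_device("transformer.h.23.self_attention.query_key_value.weight", device_map))  # disk
```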


@unittest.skip("Skipping test class because serialization is not supported yet")
class QuantoQuantizationSerializationTest(QuantoQuantizationTest):
    """
    Perform the same tests as in QuantoQuantizationTest but with a serialized model.
    """

    def setUp(self):
        """
        Setup quantized model
        """
        quantization_config = QuantoConfig(
            weights=self.weights,
            activations=self.activations,
        )
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            device_map=self.device_map,
            quantization_config=quantization_config,
            torch_dtype=torch.float32,
        )
        with tempfile.TemporaryDirectory() as tmpdirname:
            quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
            self.quantized_model = AutoModelForCausalLM.from_pretrained(
                tmpdirname, torch_dtype=torch.float32, device_map=self.device_map
            )

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        self.have_accelerate_hooks = (
            getattr(self.quantized_model, "hf_device_map", False) and len(self.quantized_model.hf_device_map) > 1
        )


@unittest.skip("Skipping test class because serialization is not supported yet")
class QuantoQuantizationSerializationCudaTest(QuantoQuantizationTest):
    """
    Perform the same tests as in QuantoQuantizationTest but with the model on cuda
    """

    device_map = "cuda:0"


class QuantoQuantizationQBitsTensorTest(QuantoQuantizationTest):
    EXPECTED_OUTPUTS = "Hello my name is John, I am a young man from the Philippines"
    weights = "int4"


class QuantoQuantizationQBitsTensorOffloadTest(QuantoQuantizationOffloadTest):
    EXPECTED_OUTPUTS = "Hello my name is John, I am a young man from the Philippines"
    weights = "int4"


@unittest.skip("Skipping test class because serialization is not supported yet")
class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializationTest):
    EXPECTED_OUTPUTS = "Hello my name is John, I am a young man from the Philippines"
    weights = "int4"


@require_torch_gpu
class QuantoQuantizationActivationTest(unittest.TestCase):
    def test_quantize_activation(self):
        quantization_config = QuantoConfig(
            weights="int8",
            activations="int8",
        )
        with self.assertRaises(ValueError) as e:
            AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", quantization_config=quantization_config)
        self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception))