I-BERT model support (#10153)

* IBertConfig, IBertTokenizer added

* IBert model names modified

* tokenizer bugfix

* embedding -> QuantEmbedding

* quant utils added

* quant_mode added to configuration

* QuantAct added, Embedding layer + QuantAct addition

* QuantAct added

* unused path removed, QKV quantized

* self attention layer all quantized, except softmax

* temporary commit

* all linear layers quantized

* quant_utils bugfix

* bugfix: requantization missing

* IntGELU added

* IntSoftmax added

* LayerNorm implemented

* LayerNorm implemented all

* names changed: roberta->ibert

* config does not inherit from RoBERTa

* No support for CausalLM

* static quantization added, quantize_model.py removed

* import modules uncommented

* copyrights fixed

* minor bugfix

* quant_modules, quant_utils merged as one file

* import * fixed

* unused runfile removed

* make style run

* configuration.py docstring fixed

* refactoring: comments removed, function name fixed

* unused dependency removed

* typo fixed

* comments (Copied from), assertion string added

* refactoring: super(..) -> super(), etc.

* refactoring

* refactoring

* make style

* refactoring

* cuda -> to(x.device)

* weight initialization removed

* QuantLinear set_param removed

* QuantEmbedding set_param removed

* IntLayerNorm set_param removed

* assert string added

* assertion error message fixed

* is_decoder removed

* enc-dec arguments/functions removed

* Converter removed

* quant_modules docstring fixed

* convert_slow_tokenizer rolled back

* quant_utils docstring fixed

* unused arguments, e.g. use_cache, removed from config

* weight initialization condition fixed

* x_min, x_max initialized with small values to avoid div-zero exceptions

* testing code for ibert

* test emb, linear, gelu, softmax added

* test ln and act added

* style reformatted

* force_dequant added

* error tests overridden

* make style

* Style + Docs

* force dequant tests added

* Fix fast tokenizer in init

* Fix doc

* Remove space

* docstring, IBertConfig, chunk_size

* test_modeling_ibert refactoring

* quant_modules.py refactoring

* e2e integration test added

* tokenizers removed

* IBertConfig added to tokenizer_auto.py

* bugfix

* fix docs & test

* fix style num 2

* final fixes

Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Sehoon Kim 2021-02-26 00:06:42 +09:00 committed by GitHub
parent cb38ffcc5e
commit 63645b3b11
12 changed files with 3279 additions and 0 deletions

View File

@ -263,6 +263,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Funnel Transformer | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| I-BERT | ❌ | ❌ | ✅ | ❌ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LED | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LXMERT | ✅ | ✅ | ✅ | ✅ | ❌ |
@ -405,6 +407,7 @@ TensorFlow and/or Flax.
model_doc/fsmt
model_doc/funnel
model_doc/herbert
model_doc/ibert
model_doc/layoutlm
model_doc/led
model_doc/longformer

View File

@ -0,0 +1,88 @@
..
Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
I-BERT
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The I-BERT model was proposed in `I-BERT: Integer-only BERT Quantization <https://arxiv.org/abs/2101.01321>`__ by
Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney and Kurt Keutzer. It's a quantized version of RoBERTa running
inference up to four times faster.
The abstract from the paper is the following:
*Transformer based models, like BERT and RoBERTa, have achieved state-of-the-art results in many Natural Language
Processing tasks. However, their memory footprint, inference latency, and power consumption are prohibitive for
efficient inference at the edge, and even at the data center. While quantization can be a viable solution for this,
previous work on quantizing Transformer based models use floating-point arithmetic during inference, which cannot
efficiently utilize integer-only logical units such as the recent Turing Tensor Cores, or traditional integer-only ARM
processors. In this work, we propose I-BERT, a novel quantization scheme for Transformer based models that quantizes
the entire inference with integer-only arithmetic. Based on lightweight integer-only approximation methods for
nonlinear operations, e.g., GELU, Softmax, and Layer Normalization, I-BERT performs an end-to-end integer-only BERT
inference without any floating point calculation. We evaluate our approach on GLUE downstream tasks using
RoBERTa-Base/Large. We show that for both cases, I-BERT achieves similar (and slightly higher) accuracy as compared to
the full-precision baseline. Furthermore, our preliminary implementation of I-BERT shows a speedup of 2.4 - 4.0x for
INT8 inference on a T4 GPU system as compared to FP32 inference. The framework has been developed in PyTorch and has
been open-sourced.*
The original code can be found `here <https://github.com/kssteven418/I-BERT>`__.
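Example (editor's sketch, not part of this diff; it assumes the kssteven/ibert-roberta-base checkpoint registered in this PR is available on the Hub — I-BERT has no tokenizer of its own and reuses the RoBERTa tokenizer, as the auto-tokenizer mapping later in this commit shows)::

    from transformers import RobertaTokenizer, IBertModel

    tokenizer = RobertaTokenizer.from_pretrained("kssteven/ibert-roberta-base")
    model = IBertModel.from_pretrained("kssteven/ibert-roberta-base")

    inputs = tokenizer("Hello, I-BERT!", return_tensors="pt")
    outputs = model(**inputs)
    print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)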
IBertConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.IBertConfig
:members:
IBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.IBertModel
:members: forward
IBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.IBertForMaskedLM
:members: forward
IBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.IBertForSequenceClassification
:members: forward
IBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.IBertForMultipleChoice
:members: forward
IBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.IBertForTokenClassification
:members: forward
IBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.IBertForQuestionAnswering
:members: forward

View File

@ -182,6 +182,7 @@ _import_structure = {
"models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"],
"models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"],
"models.herbert": ["HerbertTokenizer"],
"models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
"models.layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMTokenizer"],
"models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
"models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],
@ -613,6 +614,20 @@ if is_torch_available():
"load_tf_weights_in_gpt2",
]
)
_import_structure["models.ibert"].extend(
[
"IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"IBertForMaskedLM",
"IBertForMultipleChoice",
"IBertForQuestionAnswering",
"IBertForSequenceClassification",
"IBertForTokenClassification",
"IBertLayer",
"IBertModel",
"IBertPreTrainedModel",
"load_tf_weights_in_ibert",
]
)
_import_structure["models.layoutlm"].extend(
[
"LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",
@ -1328,6 +1343,7 @@ if TYPE_CHECKING:
from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer
from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
from .models.herbert import HerbertTokenizer
from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer
from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer
@ -1710,6 +1726,15 @@ if TYPE_CHECKING:
GPT2PreTrainedModel,
load_tf_weights_in_gpt2,
)
from .models.ibert import (
IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
IBertForMaskedLM,
IBertForMultipleChoice,
IBertForQuestionAnswering,
IBertForSequenceClassification,
IBertForTokenClassification,
IBertModel,
)
from .models.layoutlm import (
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
LayoutLMForMaskedLM,

View File

@ -40,6 +40,7 @@ from ..flaubert.configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE
from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
@ -110,6 +111,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items()
)
@ -123,6 +125,7 @@ CONFIG_MAPPING = OrderedDict(
("led", LEDConfig),
("blenderbot-small", BlenderbotSmallConfig),
("retribert", RetriBertConfig),
("ibert", IBertConfig),
("mt5", MT5Config),
("t5", T5Config),
("mobilebert", MobileBertConfig),
@ -173,6 +176,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("led", "LED"),
("blenderbot-small", "BlenderbotSmall"),
("retribert", "RetriBERT"),
("ibert", "I-BERT"),
("t5", "T5"),
("mobilebert", "MobileBERT"),
("distilbert", "DistilBERT"),

View File

@ -129,6 +129,14 @@ from ..funnel.modeling_funnel import (
FunnelModel,
)
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
from ..ibert.modeling_ibert import (
IBertForMaskedLM,
IBertForMultipleChoice,
IBertForQuestionAnswering,
IBertForSequenceClassification,
IBertForTokenClassification,
IBertModel,
)
from ..layoutlm.modeling_layoutlm import (
LayoutLMForMaskedLM,
LayoutLMForSequenceClassification,
@ -270,6 +278,7 @@ from .configuration_auto import (
FSMTConfig,
FunnelConfig,
GPT2Config,
IBertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
@ -347,6 +356,7 @@ MODEL_MAPPING = OrderedDict(
(MPNetConfig, MPNetModel),
(TapasConfig, TapasModel),
(MarianConfig, MarianModel),
(IBertConfig, IBertModel),
]
)
@ -379,6 +389,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
(FunnelConfig, FunnelForPreTraining),
(MPNetConfig, MPNetForMaskedLM),
(TapasConfig, TapasForMaskedLM),
(IBertConfig, IBertForMaskedLM),
]
)
@ -418,6 +429,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(TapasConfig, TapasForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
(DebertaV2Config, DebertaV2ForMaskedLM),
(IBertConfig, IBertForMaskedLM),
]
)
@ -476,6 +488,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
(TapasConfig, TapasForMaskedLM),
(DebertaConfig, DebertaForMaskedLM),
(DebertaV2Config, DebertaV2ForMaskedLM),
(IBertConfig, IBertForMaskedLM),
]
)
@ -529,6 +542,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
(TransfoXLConfig, TransfoXLForSequenceClassification),
(MPNetConfig, MPNetForSequenceClassification),
(TapasConfig, TapasForSequenceClassification),
(IBertConfig, IBertForSequenceClassification),
]
)
@ -558,6 +572,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
(MPNetConfig, MPNetForQuestionAnswering),
(DebertaConfig, DebertaForQuestionAnswering),
(DebertaV2Config, DebertaV2ForQuestionAnswering),
(IBertConfig, IBertForQuestionAnswering),
]
)
@ -591,6 +606,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
(MPNetConfig, MPNetForTokenClassification),
(DebertaConfig, DebertaForTokenClassification),
(DebertaV2Config, DebertaV2ForTokenClassification),
(IBertConfig, IBertForTokenClassification),
]
)
@ -613,6 +629,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
(FlaubertConfig, FlaubertForMultipleChoice),
(FunnelConfig, FunnelForMultipleChoice),
(MPNetConfig, MPNetForMultipleChoice),
(IBertConfig, IBertForMultipleChoice),
]
)

View File

@ -75,6 +75,7 @@ from .configuration_auto import (
FSMTConfig,
FunnelConfig,
GPT2Config,
IBertConfig,
LayoutLMConfig,
LEDConfig,
LongformerConfig,
@ -244,6 +245,7 @@ TOKENIZER_MAPPING = OrderedDict(
(TapasConfig, (TapasTokenizer, None)),
(LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
(ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)),
(IBertConfig, (RobertaTokenizer, RobertaTokenizerFast)),
(Wav2Vec2Config, (Wav2Vec2CTCTokenizer, None)),
]
)

View File

@ -0,0 +1,69 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING
from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available
_import_structure = {
"configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
}
if is_torch_available():
_import_structure["modeling_ibert"] = [
"IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
"IBertForMaskedLM",
"IBertForMultipleChoice",
"IBertForQuestionAnswering",
"IBertForSequenceClassification",
"IBertForTokenClassification",
"IBertModel",
]
if TYPE_CHECKING:
from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
if is_torch_available():
from .modeling_ibert import (
IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
IBertForMaskedLM,
IBertForMultipleChoice,
IBertForQuestionAnswering,
IBertForSequenceClassification,
IBertForTokenClassification,
IBertModel,
)
else:
import importlib
import os
import sys
class _LazyModule(_BaseLazyModule):
"""
Module class that surfaces all objects but only performs associated imports when the objects are requested.
"""
__file__ = globals()["__file__"]
__path__ = [os.path.dirname(__file__)]
def _get_module(self, module_name: str):
return importlib.import_module("." + module_name, self.__name__)
sys.modules[__name__] = _LazyModule(__name__, _import_structure)

View File

@ -0,0 +1,125 @@
# coding=utf-8
# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" I-BERT configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"kssteven/ibert-roberta-base": "https://huggingface.co/kssteven/ibert-roberta-base/resolve/main/config.json",
"kssteven/ibert-roberta-large": "https://huggingface.co/kssteven/ibert-roberta-large/resolve/main/config.json",
"kssteven/ibert-roberta-large-mnli": "https://huggingface.co/kssteven/ibert-roberta-large-mnli/resolve/main/config.json",
}
class IBertConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of a :class:`~transformers.IBertModel`. It is used to
instantiate an I-BERT model according to the specified arguments, defining the model architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Args:
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the I-BERT model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.IBertModel`
hidden_size (:obj:`int`, `optional`, defaults to 768):
Dimensionality of the encoder layers and the pooler layer.
num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
Number of hidden layers in the Transformer encoder.
num_attention_heads (:obj:`int`, `optional`, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
intermediate_size (:obj:`int`, `optional`, defaults to 3072):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout ratio for the attention probabilities.
max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
type_vocab_size (:obj:`int`, `optional`, defaults to 2):
The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.IBertModel`
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
:obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
<https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
`Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
<https://arxiv.org/abs/2009.13658>`__.
quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to quantize the model or not.
force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
Force dequantize specific nonlinear layers. Dequantized layers are then executed with full precision.
:obj:`"none"`, :obj:`"gelu"`, :obj:`"softmax"`, :obj:`"layernorm"` and :obj:`"nonlinear"` are supported. By
default, it is set to :obj:`"none"`, which does not dequantize any layers. Please specify :obj:`"gelu"`,
:obj:`"softmax"`, or :obj:`"layernorm"` to dequantize GELU, Softmax, or LayerNorm, respectively.
:obj:`"nonlinear"` will dequantize all nonlinear layers, i.e., GELU, Softmax, and LayerNorm.
"""
model_type = "ibert"
def __init__(
self,
vocab_size=30522,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
position_embedding_type="absolute",
quant_mode=False,
force_dequant="none",
**kwargs
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_act = hidden_act
self.intermediate_size = intermediate_size
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.max_position_embeddings = max_position_embeddings
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.position_embedding_type = position_embedding_type
self.quant_mode = quant_mode
self.force_dequant = force_dequant
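A minimal sketch of how this configuration might be used (editor's example, not part of the diff): quant_mode=True switches the model to integer-only inference, while force_dequant keeps selected nonlinear layers in full precision.

    from transformers import IBertConfig, IBertModel

    # Hypothetical usage: integer-only I-BERT with GELU forced back to floating point.
    config = IBertConfig(quant_mode=True, force_dequant="gelu")
    model = IBertModel(config)
    print(model.config.quant_mode, model.config.force_dequant)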

File diff suppressed because it is too large.

View File

@ -0,0 +1,829 @@
# coding=utf-8
# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import decimal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from ...utils import logging
logger = logging.get_logger(__name__)
class QuantEmbedding(nn.Module):
"""
Quantized version of :obj:`torch.nn.Embedding`. Adds quantization-specific arguments on top of
:obj:`torch.nn.Embedding`.
Args:
weight_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
Bitwidth for the quantized weight.
momentum (:obj:`float`, `optional`, defaults to :obj:`0.95`):
Momentum for updating the activation quantization range.
quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the layer is quantized.
"""
def __init__(
self,
num_embeddings,
embedding_dim,
padding_idx=None,
max_norm=None,
norm_type=2.0,
scale_grad_by_freq=False,
sparse=False,
_weight=None,
weight_bit=8,
momentum=0.95,
quant_mode=False,
):
super().__init__()
self.num_ = num_embeddings
self.dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
self.weight = nn.Parameter(torch.zeros([num_embeddings, embedding_dim]))
self.register_buffer("weight_scaling_factor", torch.zeros(1))
self.register_buffer("weight_integer", torch.zeros_like(self.weight))
self.weight_bit = weight_bit
self.momentum = momentum
self.quant_mode = quant_mode
self.percentile_mode = False
self.weight_function = SymmetricQuantFunction.apply
def forward(self, x, positions=None, incremental_state=None):
if not self.quant_mode:
return (
F.embedding(
x,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
),
None,
)
w = self.weight
w_transform = w.data.detach()
w_min = w_transform.min().expand(1)
w_max = w_transform.max().expand(1)
self.weight_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, False)
self.weight_integer = self.weight_function(
self.weight, self.weight_bit, self.percentile_mode, self.weight_scaling_factor
)
emb_int = F.embedding(
x,
self.weight_integer,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
return emb_int * self.weight_scaling_factor, self.weight_scaling_factor
class QuantAct(nn.Module):
"""
Quantizes the given activation.
Args:
activation_bit (:obj:`int`):
Bitwidth for the quantized activation.
act_range_momentum (:obj:`float`, `optional`, defaults to :obj:`0.95`):
Momentum for updating the activation quantization range.
per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use channel-wise quantization.
channel_len (:obj:`int`, `optional`, defaults to :obj:`None`):
Specify the channel length when `per_channel` is set to True.
quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the layer is quantized.
"""
def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, channel_len=None, quant_mode=False):
super().__init__()
self.activation_bit = activation_bit
self.act_range_momentum = act_range_momentum
self.quant_mode = quant_mode
self.per_channel = per_channel
self.percentile = False
self.act_function = SymmetricQuantFunction.apply
if not self.per_channel:
self.register_buffer("x_min", torch.zeros(1))
self.register_buffer("x_max", torch.zeros(1))
self.register_buffer("act_scaling_factor", torch.zeros(1))
self.x_min -= 1e-5
self.x_max += 1e-5
else:
raise NotImplementedError("per-channel mode is not currently supported for activation.")
def __repr__(self):
return (
"{0}(activation_bit={1}, "
"quant_mode: {2}, Act_min: {3:.2f}, "
"Act_max: {4:.2f})".format(
self.__class__.__name__, self.activation_bit, self.quant_mode, self.x_min.item(), self.x_max.item()
)
)
def forward(
self,
x,
pre_act_scaling_factor=None,
identity=None,
identity_scaling_factor=None,
specified_min=None,
specified_max=None,
):
x_act = x if identity is None else identity + x
# collect running stats if training
if self.training:
assert not self.percentile, "percentile mode is not currently supported for activation."
assert not self.per_channel, "per-channel mode is not currently supported for activation."
x_min = x_act.data.min()
x_max = x_act.data.max()
assert (
x_max.isnan().sum() == 0 and x_min.isnan().sum() == 0
), "NaN detected when computing min/max of the activation"
# Initialization
if self.x_min.min() > -1.1e-5 and self.x_max.max() < 1.1e-5:
self.x_min = self.x_min + x_min
self.x_max = self.x_max + x_max
# exponential moving average (EMA)
# use momentum to prevent the quantized values change greatly every iteration
elif self.act_range_momentum == -1:
self.x_min = torch.min(self.x_min, x_min)
self.x_max = torch.max(self.x_max, x_max)
else:
self.x_min = self.x_min * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
self.x_max = self.x_max * self.act_range_momentum + x_max * (1 - self.act_range_momentum)
if not self.quant_mode:
return x_act, None
x_min = self.x_min if specified_min is None else specified_min
x_max = self.x_max if specified_max is None else specified_max
self.act_scaling_factor = symmetric_linear_quantization_params(
self.activation_bit, x_min, x_max, per_channel=self.per_channel
)
if pre_act_scaling_factor is None:
# this is for the input quantization
quant_act_int = self.act_function(x, self.activation_bit, self.percentile, self.act_scaling_factor)
else:
quant_act_int = FixedPointMul.apply(
x,
pre_act_scaling_factor,
self.activation_bit,
self.act_scaling_factor,
identity,
identity_scaling_factor,
)
correct_output_scale = self.act_scaling_factor.view(-1)
return quant_act_int * correct_output_scale, self.act_scaling_factor
class QuantLinear(nn.Module):
"""
Quantized version of :obj:`torch.nn.Linear`. Adds quantization-specific arguments on top of :obj:`torch.nn.Linear`.
Args:
weight_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
Bitwidth for the quantized weight.
bias_bit (:obj:`int`, `optional`, defaults to :obj:`32`):
Bitwidth for the quantized bias.
per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use channel-wise quantization.
quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the layer is quantized.
"""
def __init__(
self, in_features, out_features, bias=True, weight_bit=8, bias_bit=32, per_channel=False, quant_mode=False
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.zeros([out_features, in_features]))
self.register_buffer("weight_integer", torch.zeros_like(self.weight))
self.register_buffer("fc_scaling_factor", torch.zeros(self.out_features))
if bias:
self.bias = nn.Parameter(torch.zeros(out_features))
self.register_buffer("bias_integer", torch.zeros_like(self.bias))
self.weight_bit = weight_bit
self.quant_mode = quant_mode
self.per_channel = per_channel
self.bias_bit = bias_bit
self.quant_mode = quant_mode
self.percentile_mode = False
self.weight_function = SymmetricQuantFunction.apply
def __repr__(self):
s = super().__repr__()
s = "(" + s + " weight_bit={}, quant_mode={})".format(self.weight_bit, self.quant_mode)
return s
def forward(self, x, prev_act_scaling_factor=None):
if not self.quant_mode:
return F.linear(x, weight=self.weight, bias=self.bias), None
# assert that prev_act_scaling_factor is a scalar tensor
assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), (
"Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
"Please add a QuantAct layer with `per_channel = True` before this QuantAct layer"
)
w = self.weight
w_transform = w.data.detach()
if self.per_channel:
w_min, _ = torch.min(w_transform, dim=1, out=None)
w_max, _ = torch.max(w_transform, dim=1, out=None)
else:
w_min = w_transform.min().expand(1)
w_max = w_transform.max().expand(1)
self.fc_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, self.per_channel)
self.weight_integer = self.weight_function(
self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
)
bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor
if self.bias is not None:
self.bias_integer = self.weight_function(self.bias, self.bias_bit, False, bias_scaling_factor)
prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
x_int = x / prev_act_scaling_factor
return (
F.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
bias_scaling_factor,
)
class IntGELU(nn.Module):
"""
Quantized version of :obj:`torch.nn.GELU`. Adds quantization-specific arguments on top of :obj:`torch.nn.GELU`.
Args:
quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the layer is quantized.
force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
Force dequantize the layer if either "gelu" or "nonlinear" is given.
"""
def __init__(self, quant_mode=True, force_dequant="none"):
super().__init__()
self.quant_mode = quant_mode
if force_dequant in ["nonlinear", "gelu"]:
logger.info("Force dequantize gelu")
self.quant_mode = False
if not self.quant_mode:
self.activation_fn = nn.GELU()
self.k = 1.4142
self.const = 14 # dummy integer constant
self.coeff = [-0.2888, -1.769, 1] # a(x+b)**2 + c
self.coeff[2] /= self.coeff[0]
def int_erf(self, x_int, scaling_factor):
b_int = torch.floor(self.coeff[1] / scaling_factor)
c_int = torch.floor(self.coeff[2] / scaling_factor ** 2)
sign = torch.sign(x_int)
abs_int = torch.min(torch.abs(x_int), -b_int)
y_int = sign * ((abs_int + b_int) ** 2 + c_int)
scaling_factor = scaling_factor ** 2 * self.coeff[0]
# avoid overflow
y_int = floor_ste.apply(y_int / 2 ** self.const)
scaling_factor = scaling_factor * 2 ** self.const
return y_int, scaling_factor
def forward(self, x, scaling_factor=None):
if not self.quant_mode:
return self.activation_fn(x), None
x_int = x / scaling_factor
sigmoid_int, sigmoid_scaling_factor = self.int_erf(x_int, scaling_factor / self.k)
shift_int = 1.0 // sigmoid_scaling_factor
x_int = x_int * (sigmoid_int + shift_int)
scaling_factor = scaling_factor * sigmoid_scaling_factor / 2
return x_int * scaling_factor, scaling_factor
class IntSoftmax(nn.Module):
"""
Quantized version of :obj:`torch.nn.Softmax`. Adds quantization-specific arguments on top of
:obj:`torch.nn.Softmax`.
Args:
output_bit (:obj:`int`):
Bitwidth for the layer output activation.
quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the layer is quantized.
force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
Force dequantize the layer if either "softmax" or "nonlinear" is given.
"""
def __init__(self, output_bit, quant_mode=False, force_dequant="none"):
super().__init__()
self.output_bit = output_bit
self.max_bit = 32
self.quant_mode = quant_mode
if force_dequant in ["nonlinear", "softmax"]:
logger.info("Force dequantize softmax")
self.quant_mode = False
self.act = QuantAct(16, quant_mode=self.quant_mode)
self.x0 = -0.6931 # -ln2
self.const = 30 # dummy integer constant
self.coef = [0.35815147, 0.96963238, 1.0] # ax**2 + bx + c
self.coef[1] /= self.coef[0]
self.coef[2] /= self.coef[0]
def int_polynomial(self, x_int, scaling_factor):
with torch.no_grad():
b_int = torch.floor(self.coef[1] / scaling_factor)
c_int = torch.floor(self.coef[2] / scaling_factor ** 2)
z = (x_int + b_int) * x_int + c_int
scaling_factor = self.coef[0] * scaling_factor ** 2
return z, scaling_factor
def int_exp(self, x_int, scaling_factor):
with torch.no_grad():
x0_int = torch.floor(self.x0 / scaling_factor)
x_int = torch.max(x_int, self.const * x0_int)
q = floor_ste.apply(x_int / x0_int)
r = x_int - x0_int * q
exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor)
exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.const - q)), min=0)
scaling_factor = exp_scaling_factor / 2 ** self.const
return exp_int, scaling_factor
def forward(self, x, scaling_factor):
if not self.quant_mode:
return nn.Softmax(dim=-1)(x), None
x_int = x / scaling_factor
x_int_max, _ = x_int.max(dim=-1, keepdim=True)
x_int = x_int - x_int_max
exp_int, exp_scaling_factor = self.int_exp(x_int, scaling_factor)
# Avoid overflow
exp, exp_scaling_factor = self.act(exp_int, exp_scaling_factor)
exp_int = exp / exp_scaling_factor
exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
factor = floor_ste.apply(2 ** self.max_bit / exp_int_sum)
exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit))
scaling_factor = 1 / 2 ** self.output_bit
return exp_int * scaling_factor, scaling_factor
class IntLayerNorm(nn.Module):
"""
Quantized version of :obj:`torch.nn.LayerNorm`. Adds quantization-specific arguments on top of
:obj:`torch.nn.LayerNorm`.
Args:
output_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
Bitwidth for the layer output activation.
quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not the layer is quantized.
force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
Force dequantize the layer if either "layernorm" or "nonlinear" is given.
"""
def __init__(self, normalized_shape, eps, output_bit=8, quant_mode=False, force_dequant="none"):
super().__init__()
self.normalized_shape = normalized_shape
self.eps = eps
self.weight = nn.Parameter(torch.zeros(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.quant_mode = quant_mode
if force_dequant in ["nonlinear", "layernorm"]:
logger.info("Force dequantize layernorm")
self.quant_mode = False
self.register_buffer("shift", torch.zeros(1))
self.output_bit = output_bit
self.max_bit = 32
self.dim_sqrt = None
self.activation = QuantAct(self.output_bit, quant_mode=self.quant_mode)
def set_shift(self, y_int):
with torch.no_grad():
y_sq_int = y_int ** 2
var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
shift = (torch.log2(torch.sqrt(var_int / 2 ** self.max_bit)).ceil()).max()
shift_old = self.shift
self.shift = torch.max(self.shift, shift)
logger.info("Dynamic shift adjustment: {} -> {}".format(int(shift_old), int(self.shift)))
def overflow_fallback(self, y_int):
"""
This fallback function is called when overflow is detected during training time, and adjusts the `self.shift`
to avoid overflow in the subsequent runs.
"""
self.set_shift(y_int) # adjusts `self.shift`
y_int_shifted = floor_ste.apply(y_int / 2 ** self.shift)
y_sq_int = y_int_shifted ** 2
var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
return var_int
def forward(self, x, scaling_factor=None):
if not self.quant_mode:
mean = x.mean(axis=2, keepdim=True)
y = x - mean
var = torch.mean(y ** 2, axis=2, keepdim=True)
x = y / torch.sqrt(self.eps + var)
x = x * self.weight + self.bias
return x, None
# compute sqrt of the feature dimension if it is the first run
if self.dim_sqrt is None:
n = torch.tensor(x.shape[2], dtype=torch.float)
self.dim_sqrt = torch.sqrt(n).to(x.device)
# Normalization: computes mean and variance(std)
x_int = x / scaling_factor
mean_int = round_ste.apply(x_int.mean(axis=2, keepdim=True))
y_int = x_int - mean_int
y_int_shifted = floor_ste.apply(y_int / 2 ** self.shift)
y_sq_int = y_int_shifted ** 2
var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
# overflow handling in training time
if self.training:
# if overflow is detected
if var_int.max() >= 2 ** self.max_bit:
var_int = self.overflow_fallback(y_int)
assert var_int.max() < 2 ** self.max_bit + 0.1, (
"Error detected in overflow handling: "
"`var_int` exceeds `self.max_bit` (the maximum possible bit width)"
)
# To be replaced with integer-sqrt kernel that produces the same output
std_int = floor_ste.apply(torch.sqrt(var_int)) * 2 ** self.shift
factor = floor_ste.apply(2 ** 31 / std_int)
y_int = floor_ste.apply(y_int * factor / 2)
scaling_factor = self.dim_sqrt / 2 ** 30
# scaling and shifting
bias = self.bias.data.detach() / (self.weight.data.detach())
bias_int = floor_ste.apply(bias / scaling_factor)
y_int = y_int + bias_int
scaling_factor = scaling_factor * self.weight
x = y_int * scaling_factor
return x, scaling_factor
def get_percentile_min_max(input, lower_percentile, upper_percentile, output_tensor=False):
"""
Calculate the percentile max and min values in a given tensor
Args:
input (:obj:`torch.Tensor`):
The target tensor to calculate percentile max and min.
lower_percentile (:obj:`float`):
If set to 0.1, the value of the smallest 0.1% of values in the tensor is returned as the percentile min.
upper_percentile (:obj:`float`):
If set to 99.9, the value of the largest 0.1% of values in the tensor is returned as the percentile max.
output_tensor (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, this function returns tensors, otherwise it returns values.
Returns:
:obj:`Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of `input`
"""
input_length = input.shape[0]
lower_index = round(input_length * (1 - lower_percentile * 0.01))
upper_index = round(input_length * upper_percentile * 0.01)
upper_bound = torch.kthvalue(input, k=upper_index).values
if lower_percentile == 0:
lower_bound = upper_bound * 0
# lower_index += 1
else:
lower_bound = -torch.kthvalue(-input, k=lower_index).values
if not output_tensor:
lower_bound = lower_bound.item()
upper_bound = upper_bound.item()
return lower_bound, upper_bound
def linear_quantize(input, scale, zero_point, inplace=False):
"""
Quantize single-precision input tensor to integers with the given scaling factor and zeropoint.
Args:
input (:obj:`torch.Tensor`):
Single-precision input tensor to be quantized.
scale (:obj:`torch.Tensor`):
Scaling factor for quantization.
zero_point (:obj:`torch.Tensor`):
Shift for quantization.
inplace (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to compute inplace or not.
Returns:
:obj:`torch.Tensor`: Linearly quantized value of `input` according to `scale` and `zero_point`.
"""
# reshape scale and zeropoint for convolutional weights and activation
if len(input.shape) == 4:
scale = scale.view(-1, 1, 1, 1)
zero_point = zero_point.view(-1, 1, 1, 1)
# reshape scale and zeropoint for linear weights
elif len(input.shape) == 2:
scale = scale.view(-1, 1)
zero_point = zero_point.view(-1, 1)
else:
scale = scale.view(-1)
zero_point = zero_point.view(-1)
# quantized = float / scale + zero_point
if inplace:
input.mul_(1.0 / scale).add_(zero_point).round_()
return input
return torch.round(1.0 / scale * input + zero_point)
def symmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, per_channel=False):
"""
Compute the scaling factor with the given quantization range for symmetric quantization.
Args:
saturation_min (:obj:`torch.Tensor`):
Lower bound for quantization range.
saturation_max (:obj:`torch.Tensor`):
Upper bound for quantization range.
per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use channel-wise quantization.
Returns:
:obj:`torch.Tensor`: Scaling factor that linearly quantizes the given range between `saturation_min` and
`saturation_max`.
"""
# this part does not require any gradient computation;
# to enforce this, we wrap it in torch.no_grad()
with torch.no_grad():
n = 2 ** (num_bits - 1) - 1
if per_channel:
scale, _ = torch.max(torch.stack([saturation_min.abs(), saturation_max.abs()], dim=1), dim=1)
scale = torch.clamp(scale, min=1e-8) / n
else:
scale = max(saturation_min.abs(), saturation_max.abs())
scale = torch.clamp(scale, min=1e-8) / n
return scale
class SymmetricQuantFunction(Function):
"""
Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth.
"""
@staticmethod
def forward(ctx, x, k, percentile_mode, scale):
"""
Args:
x (:obj:`torch.Tensor`):
Floating point tensor to be quantized.
k (:obj:`int`):
Quantization bitwidth.
percentile_mode (:obj:`bool`):
Whether or not to use percentile calibration.
scale (:obj:`torch.Tensor`):
Pre-calculated scaling factor for `x`. Note that the current implementation of SymmetricQuantFunction
requires pre-calculated scaling factor.
Returns:
:obj:`torch.Tensor`: Symmetric-quantized value of `input`.
"""
zero_point = torch.tensor(0.0).to(scale.device)
n = 2 ** (k - 1) - 1
new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
new_quant_x = torch.clamp(new_quant_x, -n, n - 1)
ctx.scale = scale
return new_quant_x
@staticmethod
def backward(ctx, grad_output):
scale = ctx.scale
if len(grad_output.shape) == 4:
scale = scale.view(-1, 1, 1, 1)
# reshape scale and zeropoint for linear weights
elif len(grad_output.shape) == 2:
scale = scale.view(-1, 1)
else:
scale = scale.view(-1)
return grad_output.clone() / scale, None, None, None, None
class floor_ste(Function):
"""
Straight-through estimator (STE) for torch.floor()
"""
@staticmethod
def forward(ctx, x):
return torch.floor(x)
@staticmethod
def backward(ctx, grad_output):
return grad_output.clone()
class round_ste(Function):
"""
Straight-through estimator (STE) for torch.round()
"""
@staticmethod
def forward(ctx, x):
return torch.round(x)
@staticmethod
def backward(ctx, grad_output):
return grad_output.clone()
def batch_frexp(inputs, max_bit=31):
"""
Decompose the scaling factor into mantissa and two's exponent.
Args:
inputs (:obj:`torch.Tensor`):
Target scaling factor to decompose.
Returns:
:obj:`Tuple(torch.Tensor, torch.Tensor)`: mantissa and exponent
"""
shape_of_input = inputs.size()
# flatten the input into a 1-d tensor
inputs = inputs.view(-1)
output_m, output_e = np.frexp(inputs.cpu().numpy())
tmp_m = []
for m in output_m:
int_m_shifted = int(
decimal.Decimal(m * (2 ** max_bit)).quantize(decimal.Decimal("1"), rounding=decimal.ROUND_HALF_UP)
)
tmp_m.append(int_m_shifted)
output_m = np.array(tmp_m)
output_e = float(max_bit) - output_e
return (
torch.from_numpy(output_m).to(inputs.device).view(shape_of_input),
torch.from_numpy(output_e).to(inputs.device).view(shape_of_input),
)
class FixedPointMul(Function):
"""
Function to perform fixed-point arithmetic that can match integer arithmetic on hardware.
Args:
pre_act (:obj:`torch.Tensor`):
Input tensor.
pre_act_scaling_factor (:obj:`torch.Tensor`):
Scaling factor of the input tensor `pre_act`.
bit_num (:obj:`int`):
Quantization bitwidth.
z_scaling_factor (:obj:`torch.Tensor`):
Scaling factor of the output tensor.
identity (:obj:`torch.Tensor`, `optional`, defaults to :obj:`None`):
Identity tensor, if exists.
identity_scaling_factor (:obj:`torch.Tensor`, `optional`, defaults to :obj:`None`):
Scaling factor of the identity tensor `identity`, if exists.
Returns:
:obj:`torch.Tensor`: Output tensor(`pre_act` if `identity` is not given, otherwise the addition of `pre_act`
and `identity`), whose scale is rescaled to `z_scaling_factor`.
"""
@staticmethod
def forward(
ctx,
pre_act,
pre_act_scaling_factor,
bit_num,
z_scaling_factor,
identity=None,
identity_scaling_factor=None,
):
if len(pre_act_scaling_factor.shape) == 3:
reshape = lambda x: x # noqa: E731
else:
reshape = lambda x: x.view(1, 1, -1) # noqa: E731
ctx.identity = identity
n = 2 ** (bit_num - 1) - 1
with torch.no_grad():
pre_act_scaling_factor = reshape(pre_act_scaling_factor)
if identity is not None:
identity_scaling_factor = reshape(identity_scaling_factor)
ctx.z_scaling_factor = z_scaling_factor
z_int = torch.round(pre_act / pre_act_scaling_factor)
_A = pre_act_scaling_factor.type(torch.double)
_B = (z_scaling_factor.type(torch.float)).type(torch.double)
new_scale = _A / _B
new_scale = reshape(new_scale)
m, e = batch_frexp(new_scale)
output = z_int.type(torch.double) * m.type(torch.double)
output = torch.round(output / (2.0 ** e))
if identity is not None:
# needs addition of identity activation
wx_int = torch.round(identity / identity_scaling_factor)
_A = identity_scaling_factor.type(torch.double)
_B = (z_scaling_factor.type(torch.float)).type(torch.double)
new_scale = _A / _B
new_scale = reshape(new_scale)
m1, e1 = batch_frexp(new_scale)
output1 = wx_int.type(torch.double) * m1.type(torch.double)
output1 = torch.round(output1 / (2.0 ** e1))
output = output1 + output
return torch.clamp(output.type(torch.float), -n - 1, n)
@staticmethod
def backward(ctx, grad_output):
identity_grad = None
if ctx.identity is not None:
identity_grad = grad_output.clone() / ctx.z_scaling_factor
return grad_output.clone() / ctx.z_scaling_factor, None, None, None, None, identity_grad, None
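To illustrate how the building blocks above compose, here is a minimal editor's sketch (not part of the diff). It assumes QuantAct and QuantLinear are importable from transformers.models.ibert.modeling_ibert, as the tests added later in this commit do: an 8-bit activation quantizer feeds a quantized linear layer, and each module returns its output together with the scaling factor the next stage consumes.

    import torch
    from transformers.models.ibert.modeling_ibert import QuantAct, QuantLinear

    act = QuantAct(activation_bit=8, quant_mode=True)  # tracks activation ranges, returns rescaled int output
    fc = QuantLinear(16, 4, quant_mode=True)           # 8-bit weights, 32-bit bias (zero-initialized here)

    x = torch.randn(2, 16)
    x_q, act_scale = act(x)            # quantized activation and its (1,)-shaped scaling factor
    y, out_scale = fc(x_q, act_scale)  # integer matmul rescaled by weight and activation scales
    print(y.shape, out_scale.shape)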

View File

@ -1349,6 +1349,63 @@ def load_tf_weights_in_gpt2(*args, **kwargs):
requires_pytorch(load_tf_weights_in_gpt2)
IBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class IBertForMaskedLM:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class IBertForMultipleChoice:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class IBertForQuestionAnswering:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class IBertForSequenceClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class IBertForTokenClassification:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class IBertModel:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = None

tests/test_modeling_ibert.py (696 lines, new executable file)
View File

@ -0,0 +1,696 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
if is_torch_available():
import torch
import torch.nn as nn
from transformers import (
IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
IBertConfig,
IBertForMaskedLM,
IBertForMultipleChoice,
IBertForQuestionAnswering,
IBertForSequenceClassification,
IBertForTokenClassification,
IBertModel,
)
from transformers.models.ibert.modeling_ibert import (
IBertEmbeddings,
IntGELU,
IntLayerNorm,
IntSoftmax,
QuantAct,
QuantEmbedding,
QuantLinear,
create_position_ids_from_input_ids,
)
class IBertModelTester:
def __init__(
self,
parent,
):
self.parent = parent
self.batch_size = 13
self.seq_length = 7
self.is_training = True
self.use_input_mask = True
self.use_token_type_ids = True
self.use_labels = True
self.vocab_size = 99
self.hidden_size = 32
self.num_hidden_layers = 5
self.num_attention_heads = 4
self.intermediate_size = 37
self.hidden_act = "gelu"
self.hidden_dropout_prob = 0.1
self.attention_probs_dropout_prob = 0.1
self.max_position_embeddings = 512
self.type_vocab_size = 16
self.type_sequence_label_size = 2
self.initializer_range = 0.02
self.num_labels = 3
self.num_choices = 4
self.scope = None
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
token_type_ids = None
if self.use_token_type_ids:
token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
sequence_labels = None
token_labels = None
choice_labels = None
if self.use_labels:
sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
choice_labels = ids_tensor([self.batch_size], self.num_choices)
config = IBertConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
hidden_act=self.hidden_act,
hidden_dropout_prob=self.hidden_dropout_prob,
attention_probs_dropout_prob=self.attention_probs_dropout_prob,
max_position_embeddings=self.max_position_embeddings,
type_vocab_size=self.type_vocab_size,
initializer_range=self.initializer_range,
quant_mode=True,
)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = IBertModel(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
result = model(input_ids, token_type_ids=token_type_ids)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = IBertForMaskedLM(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
model = IBertForTokenClassification(config=config)
model.to(torch_device)
model.eval()
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = IBertForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
result = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = IBertForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
result = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class IBertModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False
test_torchscript = False
test_head_masking = False
test_resize_embeddings = False
all_model_classes = (
(
IBertForMaskedLM,
IBertModel,
IBertForSequenceClassification,
IBertForTokenClassification,
IBertForMultipleChoice,
IBertForQuestionAnswering,
)
if is_torch_available()
else ()
)
def setUp(self):
self.model_tester = IBertModelTester(self)
self.config_tester = ConfigTester(self, config_class=IBertConfig, hidden_size=37)
def test_config(self):
self.config_tester.run_common_tests()
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
# I-BERT only supports absolute embedding
for type in ["absolute"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in IBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = IBertModel.from_pretrained(model_name)
self.assertIsNotNone(model)
def test_create_position_ids_respects_padding_index(self):
"""Ensure that the default position ids only assign a sequential . This is a regression
test for https://github.com/huggingface/transformers/issues/1761
The position ids should be masked with the embedding object's padding index. Therefore, the
first available non-padding position index is IBertEmbeddings.padding_idx + 1
"""
config = self.model_tester.prepare_config_and_inputs()[0]
model = IBertEmbeddings(config=config)
input_ids = torch.as_tensor([[12, 31, 13, model.padding_idx]])
expected_positions = torch.as_tensor(
[[0 + model.padding_idx + 1, 1 + model.padding_idx + 1, 2 + model.padding_idx + 1, model.padding_idx]]
)
position_ids = create_position_ids_from_input_ids(input_ids, model.padding_idx)
self.assertEqual(position_ids.shape, expected_positions.shape)
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
def test_create_position_ids_from_inputs_embeds(self):
"""Ensure that the default position ids only assign a sequential . This is a regression
test for https://github.com/huggingface/transformers/issues/1761
The position ids should be masked with the embedding object's padding index. Therefore, the
first available non-padding position index is IBertEmbeddings.padding_idx + 1
"""
config = self.model_tester.prepare_config_and_inputs()[0]
embeddings = IBertEmbeddings(config=config)
inputs_embeds = torch.Tensor(2, 4, 30)
expected_single_positions = [
0 + embeddings.padding_idx + 1,
1 + embeddings.padding_idx + 1,
2 + embeddings.padding_idx + 1,
3 + embeddings.padding_idx + 1,
]
expected_positions = torch.as_tensor([expected_single_positions, expected_single_positions])
position_ids = embeddings.create_position_ids_from_inputs_embeds(inputs_embeds)
self.assertEqual(position_ids.shape, expected_positions.shape)
self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
# Override
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), QuantEmbedding)
model.set_input_embeddings(torch.nn.Embedding(10, 10))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, torch.nn.Linear))
# Override
def test_feed_forward_chunking(self):
pass # I-BERT does not support chunking
# Override
def test_inputs_embeds(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
model.to(torch_device)
model.eval()
inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
if not self.is_encoder_decoder:
input_ids = inputs["input_ids"]
del inputs["input_ids"]
else:
encoder_input_ids = inputs["input_ids"]
decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
del inputs["input_ids"]
inputs.pop("decoder_input_ids", None)
wte = model.get_input_embeddings()
if not self.is_encoder_decoder:
embed, embed_scaling_factor = wte(input_ids)
inputs["inputs_embeds"] = embed
else:
inputs["inputs_embeds"] = wte(encoder_input_ids)
inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
with torch.no_grad():
model(**inputs)[0]
@require_torch
class IBertModelIntegrationTest(unittest.TestCase):
def test_quant_embedding(self):
weight_bit = 8
embedding = QuantEmbedding(2, 4, quant_mode=True, weight_bit=weight_bit)
embedding_weight = torch.tensor([[-1.0, -2.0, -3.0, -4.0], [5.0, 6.0, 7.0, 8.0]])
embedding.weight = torch.nn.Parameter(embedding_weight)
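        # Symmetric quantization: scale = max|w| / (2^(bit-1) - 1). With max|w| = 8.0 and 8-bit
        # weights this gives 8.0 / 127, so the integer weights span the full [-127, 127] range.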
expected_scaling_factor = embedding_weight.abs().max() / (2 ** (weight_bit - 1) - 1)
x, x_scaling_factor = embedding(torch.tensor(0))
y, y_scaling_factor = embedding(torch.tensor(1))
# scaling factor should follow the symmetric quantization rule
        self.assertTrue(torch.allclose(x_scaling_factor, expected_scaling_factor, atol=1e-4))
self.assertTrue(torch.allclose(y_scaling_factor, expected_scaling_factor, atol=1e-4))
# quantization error should not exceed the scaling factor
self.assertTrue(torch.allclose(x, embedding_weight[0], atol=expected_scaling_factor))
self.assertTrue(torch.allclose(y, embedding_weight[1], atol=expected_scaling_factor))
def test_quant_act(self):
def _test_range():
act = QuantAct(activation_bit, act_range_momentum, quant_mode=True)
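            # QuantAct tracks a running min/max of its inputs (an exponential moving average
            # controlled by `act_range_momentum`) and derives the activation scaling factor from it.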
# First pass
x = torch.tensor([[-1.0, -2.0, -3.0, -4.0], [5.0, 6.0, 7.0, 8.0]])
x_scaling_factor = torch.tensor(1.0)
y, y_scaling_factor = act(x, x_scaling_factor)
y_int = y / y_scaling_factor
# After the first pass, x_min and x_max should be initialized with x.min() and x.max()
expected_x_min, expected_x_max = x.min(), x.max()
self.assertTrue(torch.allclose(act.x_min, expected_x_min, atol=1e-4))
self.assertTrue(torch.allclose(act.x_max, expected_x_max, atol=1e-4))
# scaling factor should follow the symmetric quantization rule
expected_range = torch.max(expected_x_min.abs(), expected_x_max.abs())
expected_scaling_factor = expected_range / (2 ** (activation_bit - 1) - 1)
self.assertTrue(torch.allclose(y_scaling_factor, expected_scaling_factor, atol=1e-4))
# quantization error should not exceed the scaling factor
self.assertTrue(torch.allclose(x, y, atol=expected_scaling_factor))
# output should be integer
self.assertTrue(torch.allclose(y_int, y_int.round(), atol=1e-4))
# Second Pass
x = torch.tensor([[-1.0, -2.0, -3.0, -4.0], [5.0, 6.0, 7.0, 8.0]]) * 2
x_scaling_factor = torch.tensor(1.0)
y, y_scaling_factor = act(x, x_scaling_factor)
y_int = y / y_scaling_factor
# From the second pass, x_min and x_max should be updated with moving average
expected_x_min = expected_x_min * act_range_momentum + x.min() * (1 - act_range_momentum)
expected_x_max = expected_x_max * act_range_momentum + x.max() * (1 - act_range_momentum)
self.assertTrue(torch.allclose(act.x_min, expected_x_min, atol=1e-4))
self.assertTrue(torch.allclose(act.x_max, expected_x_max, atol=1e-4))
# scaling factor should follow the symmetric quantization rule
expected_range = torch.max(expected_x_min.abs(), expected_x_max.abs())
expected_scaling_factor = expected_range / (2 ** (activation_bit - 1) - 1)
self.assertTrue(torch.allclose(y_scaling_factor, expected_scaling_factor, atol=1e-4))
# quantization error should not exceed the scaling factor
x = x.clamp(min=-expected_range, max=expected_range)
self.assertTrue(torch.allclose(x, y, atol=expected_scaling_factor))
# output should be integer
self.assertTrue(torch.allclose(y_int, y_int.round(), atol=1e-4))
# Third pass, with eval()
act.eval()
x = torch.tensor([[-1.0, -2.0, -3.0, -4.0], [5.0, 6.0, 7.0, 8.0]]) * 3
# In eval mode, min/max and scaling factor must be fixed
self.assertTrue(torch.allclose(act.x_min, expected_x_min, atol=1e-4))
self.assertTrue(torch.allclose(act.x_max, expected_x_max, atol=1e-4))
self.assertTrue(torch.allclose(y_scaling_factor, expected_scaling_factor, atol=1e-4))
def _test_identity():
# test if identity and identity_scaling_factor are given
# should add the input values
act = QuantAct(activation_bit, act_range_momentum, quant_mode=True)
x = torch.tensor([[-1.0, -2.0, -3.0, -4.0], [5.0, 6.0, 7.0, 8.0]])
y = torch.tensor([[6.0, -7.0, 1.0, -2.0], [3.0, -4.0, -8.0, 5.0]])
x_scaling_factor = torch.tensor(1.0)
y_scaling_factor = torch.tensor(0.5)
z, z_scaling_factor = act(x, x_scaling_factor, y, y_scaling_factor)
z_int = z / z_scaling_factor
self.assertTrue(torch.allclose(x + y, z, atol=0.1))
self.assertTrue(torch.allclose(z_int, z_int.round(), atol=1e-4))
activation_bit = 8
act_range_momentum = 0.95
_test_range()
_test_identity()
def test_quant_linear(self):
def _test(per_channel):
linear_q = QuantLinear(2, 4, quant_mode=True, per_channel=per_channel, weight_bit=weight_bit)
linear_dq = QuantLinear(2, 4, quant_mode=False, per_channel=per_channel, weight_bit=weight_bit)
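            # per_channel=True uses one scaling factor per output channel (row-wise max of the
            # weight), per_channel=False a single scaling factor from the global max; both follow
            # scale = max|w| / (2^(bit-1) - 1), mirroring the expected values computed below.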
linear_weight = torch.tensor([[-1.0, 2.0, 3.0, -4.0], [5.0, -6.0, -7.0, 8.0]]).T
linear_q.weight = torch.nn.Parameter(linear_weight)
linear_dq.weight = torch.nn.Parameter(linear_weight)
q, q_scaling_factor = linear_q(x, x_scaling_factor)
q_int = q / q_scaling_factor
dq, dq_scaling_factor = linear_dq(x, x_scaling_factor)
if per_channel:
q_max = linear_weight.abs().max(dim=1).values
else:
q_max = linear_weight.abs().max()
expected_scaling_factor = q_max / (2 ** (weight_bit - 1) - 1)
# scaling factor should follow the symmetric quantization rule
self.assertTrue(torch.allclose(linear_q.fc_scaling_factor, expected_scaling_factor, atol=1e-4))
# output of the normal linear layer and the quantized linear layer should be similar
self.assertTrue(torch.allclose(q, dq, atol=0.5))
# output of the quantized linear layer should be integer
self.assertTrue(torch.allclose(q_int, q_int.round(), atol=1e-4))
weight_bit = 8
x = torch.tensor([[2.0, -5.0], [-3.0, 4.0]])
x_scaling_factor = torch.tensor([1.0])
_test(True)
_test(False)
def test_int_gelu(self):
gelu_q = IntGELU(quant_mode=True)
gelu_dq = torch.nn.GELU()
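        # IntGELU approximates the erf inside GELU with a second-order integer-only polynomial
        # (i-GELU in the I-BERT paper), so its output should track nn.GELU within the quantization error.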
        x_int = torch.arange(-10000, 10001, 1.0)  # torch.range is deprecated; arange with an exclusive end is equivalent here
x_scaling_factor = torch.tensor(0.001)
x = x_int * x_scaling_factor
q, q_scaling_factor = gelu_q(x, x_scaling_factor)
q_int = q / q_scaling_factor
dq = gelu_dq(x)
# output of the normal GELU and the quantized GELU should be similar
self.assertTrue(torch.allclose(q, dq, atol=0.5))
# output of the quantized GELU layer should be integer
self.assertTrue(torch.allclose(q_int, q_int.round(), atol=1e-4))
def test_force_dequant_gelu(self):
        x_int = torch.arange(-10000, 10001, 1.0)  # torch.range is deprecated; arange with an exclusive end is equivalent here
x_scaling_factor = torch.tensor(0.001)
x = x_int * x_scaling_factor
gelu_dq = IntGELU(quant_mode=False)
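        # Dict keys below: True -> force_dequant disables the integer GELU kernel, so the output must
        # match the non-quantized module exactly; False -> the integer approximation stays active and
        # the outputs are expected to differ.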
gelu_fdqs_dict = {
True: [
IntGELU(quant_mode=True, force_dequant="nonlinear"),
IntGELU(quant_mode=True, force_dequant="gelu"),
],
False: [
IntGELU(quant_mode=True, force_dequant="none"),
IntGELU(quant_mode=True, force_dequant="softmax"),
IntGELU(quant_mode=True, force_dequant="layernorm"),
],
}
dq, dq_scaling_factor = gelu_dq(x, x_scaling_factor)
for label, gelu_fdqs in gelu_fdqs_dict.items():
for gelu_fdq in gelu_fdqs:
q, q_scaling_factor = gelu_fdq(x, x_scaling_factor)
if label:
self.assertTrue(torch.allclose(q, dq, atol=1e-4))
else:
self.assertFalse(torch.allclose(q, dq, atol=1e-4))
def test_int_softmax(self):
output_bit = 8
softmax_q = IntSoftmax(output_bit, quant_mode=True)
        softmax_dq = torch.nn.Softmax(dim=-1)  # explicit dim avoids the implicit-dim deprecation warning
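        # IntSoftmax approximates exp with an integer-only polynomial (i-exp in the I-BERT paper) and
        # normalizes in the integer domain, so it should stay close to nn.Softmax while its raw
        # (pre-rescaling) outputs remain integers.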
def _test(array):
x_int = torch.tensor(array)
x_scaling_factor = torch.tensor(0.1)
x = x_int * x_scaling_factor
q, q_scaling_factor = softmax_q(x, x_scaling_factor)
q_int = q / q_scaling_factor
dq = softmax_dq(x)
# output of the normal Softmax and the quantized Softmax should be similar
self.assertTrue(torch.allclose(q, dq, atol=0.5))
            # output of the quantized Softmax layer should be integer
self.assertTrue(torch.allclose(q_int, q_int.round(), atol=1e-4))
            # Output of the quantized Softmax should not exceed the range representable with output_bit bits
self.assertTrue(q.abs().max() < 2 ** output_bit)
array = [[i + j for j in range(10)] for i in range(-10, 10)]
_test(array)
array = [[i + j for j in range(50)] for i in range(-10, 10)]
_test(array)
array = [[i + 100 * j for j in range(2)] for i in range(-10, 10)]
_test(array)
def test_force_dequant_softmax(self):
output_bit = 8
array = [[i + j for j in range(10)] for i in range(-10, 10)]
x_int = torch.tensor(array)
x_scaling_factor = torch.tensor(0.1)
x = x_int * x_scaling_factor
softmax_dq = IntSoftmax(output_bit, quant_mode=False)
softmax_fdqs_dict = {
True: [
IntSoftmax(output_bit, quant_mode=True, force_dequant="nonlinear"),
IntSoftmax(output_bit, quant_mode=True, force_dequant="softmax"),
],
False: [
IntSoftmax(output_bit, quant_mode=True, force_dequant="none"),
IntSoftmax(output_bit, quant_mode=True, force_dequant="gelu"),
IntSoftmax(output_bit, quant_mode=True, force_dequant="layernorm"),
],
}
dq, dq_scaling_factor = softmax_dq(x, x_scaling_factor)
for label, softmax_fdqs in softmax_fdqs_dict.items():
for softmax_fdq in softmax_fdqs:
q, q_scaling_factor = softmax_fdq(x, x_scaling_factor)
if label:
self.assertTrue(torch.allclose(q, dq, atol=1e-4))
else:
self.assertFalse(torch.allclose(q, dq, atol=1e-4))
def test_int_layernorm(self):
output_bit = 8
# some random matrix
array = [[[i * j * j + j for j in range(5, 15)]] for i in range(-10, 10)]
x_int = torch.tensor(array)
x_scaling_factor = torch.tensor(0.1)
x = x_int * x_scaling_factor
ln_q = IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit)
ln_dq = torch.nn.LayerNorm(x.shape[1:], 1e-5)
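        # IntLayerNorm replaces the floating-point standard deviation with an integer iterative
        # square root (i-sqrt in the I-BERT paper), so it should match nn.LayerNorm up to the
        # quantization error.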
ln_q.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_q.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_dq.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_dq.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))
q, q_scaling_factor = ln_q(x, x_scaling_factor)
q_int = q / q_scaling_factor
dq = ln_dq(x)
# output of the normal LN and the quantized LN should be similar
self.assertTrue(torch.allclose(q, dq, atol=0.5))
        # output of the quantized LayerNorm layer should be integer
self.assertTrue(torch.allclose(q_int, q_int.round(), atol=1e-4))
def test_force_dequant_layernorm(self):
output_bit = 8
array = [[[i * j * j + j for j in range(5, 15)]] for i in range(-10, 10)]
x_int = torch.tensor(array)
x_scaling_factor = torch.tensor(0.1)
x = x_int * x_scaling_factor
ln_dq = IntLayerNorm(x.shape[1:], 1e-5, quant_mode=False, output_bit=output_bit)
ln_fdqs_dict = {
True: [
IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit, force_dequant="nonlinear"),
IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit, force_dequant="layernorm"),
],
False: [
IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit, force_dequant="none"),
IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit, force_dequant="gelu"),
IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit, force_dequant="softmax"),
],
}
ln_dq.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_dq.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))
dq, dq_scaling_factor = ln_dq(x, x_scaling_factor)
for label, ln_fdqs in ln_fdqs_dict.items():
for ln_fdq in ln_fdqs:
ln_fdq.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
ln_fdq.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))
q, q_scaling_factor = ln_fdq(x, x_scaling_factor)
if label:
self.assertTrue(torch.allclose(q, dq, atol=1e-4))
else:
self.assertFalse(torch.allclose(q, dq, atol=1e-4))
def quantize(self, model):
# Helper function that quantizes the given model
        # Recursively set all `quant_mode` attributes to `True`
if hasattr(model, "quant_mode"):
model.quant_mode = True
elif type(model) == nn.Sequential:
for n, m in model.named_children():
self.quantize(m)
elif type(model) == nn.ModuleList:
for n in model:
self.quantize(n)
else:
for attr in dir(model):
mod = getattr(model, attr)
if isinstance(mod, nn.Module) and mod != model:
self.quantize(mod)
@slow
def test_inference_masked_lm(self):
# I-BERT should be "equivalent" to RoBERTa if not quantized
        # Test copied from `test_modeling_roberta.py`
model = IBertForMaskedLM.from_pretrained("kssteven/ibert-roberta-base")
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11, 50265))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[[33.8802, -4.3103, 22.7761], [4.6539, -2.8098, 13.6253], [1.8228, -3.6898, 8.8600]]]
)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
# I-BERT should be "similar" to RoBERTa if quantized
self.quantize(model)
output = model(input_ids)[0]
self.assertEqual(output.shape, expected_shape)
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=0.1))
@slow
def test_inference_classification_head(self):
# I-BERT should be "equivalent" to RoBERTa if not quantized
        # Test copied from `test_modeling_roberta.py`
model = IBertForSequenceClassification.from_pretrained("kssteven/ibert-roberta-large-mnli")
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 3))
self.assertEqual(output.shape, expected_shape)
expected_tensor = torch.tensor([[-0.9469, 0.3913, 0.5118]])
self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))
# I-BERT should be "similar" to RoBERTa if quantized
self.quantize(model)
output = model(input_ids)[0]
self.assertEqual(output.shape, expected_shape)
self.assertTrue(torch.allclose(output, expected_tensor, atol=0.1))