I-BERT model support (#10153)
* IBertConfig, IBertTokenizer added
* IBert model names modified
* tokenizer bugfix
* embedding -> QuantEmbedding
* quant utils added
* quant_mode added to configuration
* QuantAct added, Embedding layer + QuantAct addition
* QuantAct added
* unused path removed, QKV quantized
* self-attention layer all quantized, except softmax
* temporary commit
* all linear layers quantized
* quant_utils bugfix
* bugfix: requantization missing
* IntGELU added
* IntSoftmax added
* LayerNorm implemented
* LayerNorm implemented all
* names changed: roberta -> ibert
* config does not inherit from RoBERTa
* No support for CausalLM
* static quantization added, quantize_model.py removed
* import modules uncommented
* copyrights fixed
* minor bugfix
* quant_modules, quant_utils merged as one file
* import * fixed
* unused runfile removed
* make style run
* configuration.py docstring fixed
* refactoring: comments removed, function name fixed
* unused dependency removed
* typo fixed
* comments (Copied from), assertion string added
* refactoring: super(..) -> super(), etc.
* refactoring
* refactoring
* make style
* refactoring
* cuda -> to(x.device)
* weight initialization removed
* QuantLinear set_param removed
* QuantEmbedding set_param removed
* IntLayerNorm set_param removed
* assert string added
* assertion error message fixed
* is_decoder removed
* enc-dec arguments/functions removed
* Converter removed
* quant_modules docstring fixed
* convert_slow_tokenizer rolled back
* quant_utils docstring fixed
* unused arguments, e.g. use_cache, removed from config
* weight initialization condition fixed
* x_min, x_max initialized with small values to avoid div-zero exceptions
* testing code for ibert
* test emb, linear, gelu, softmax added
* test ln and act added
* style reformatted
* force_dequant added
* error tests overridden
* make style
* Style + Docs
* force dequant tests added
* Fix fast tokenizer in init
* Fix doc
* Remove space
* docstring, IBertConfig, chunk_size
* test_modeling_ibert refactoring
* quant_modules.py refactoring
* e2e integration test added
* tokenizers removed
* IBertConfig added to tokenizer_auto.py
* bugfix
* fix docs & test
* fix style num 2
* final fixes

Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
Co-authored-by: Sylvain Gugger <sylvain.gugger@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent cb38ffcc5e
commit 63645b3b11
@@ -263,6 +263,8 @@ TensorFlow and/or Flax.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| Funnel Transformer          | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| I-BERT                      | ❌             | ❌             | ✅              | ❌                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LED                         | ✅             | ✅             | ✅              | ✅                 | ❌           |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| LXMERT                      | ✅             | ✅             | ✅              | ✅                 | ❌           |
@@ -405,6 +407,7 @@ TensorFlow and/or Flax.
    model_doc/fsmt
    model_doc/funnel
    model_doc/herbert
    model_doc/ibert
    model_doc/layoutlm
    model_doc/led
    model_doc/longformer
@@ -0,0 +1,88 @@
..
    Copyright 2020 The HuggingFace Team. All rights reserved.

    Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
    the License. You may obtain a copy of the License at

        http://www.apache.org/licenses/LICENSE-2.0

    Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
    an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
    specific language governing permissions and limitations under the License.

I-BERT
-----------------------------------------------------------------------------------------------------------------------

Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The I-BERT model was proposed in `I-BERT: Integer-only BERT Quantization <https://arxiv.org/abs/2006.10220>`__ by
Sehoon Kim, Amir Gholami, Zhewei Yao, Michael W. Mahoney and Kurt Keutzer. It's a quantized version of RoBERTa running
inference up to four times faster.

The abstract from the paper is the following:

*Transformer based models, like BERT and RoBERTa, have achieved state-of-the-art results in many Natural Language
Processing tasks. However, their memory footprint, inference latency, and power consumption are prohibitive for
efficient inference at the edge, and even at the data center. While quantization can be a viable solution for this,
previous work on quantizing Transformer based models use floating-point arithmetic during inference, which cannot
efficiently utilize integer-only logical units such as the recent Turing Tensor Cores, or traditional integer-only ARM
processors. In this work, we propose I-BERT, a novel quantization scheme for Transformer based models that quantizes
the entire inference with integer-only arithmetic. Based on lightweight integer-only approximation methods for
nonlinear operations, e.g., GELU, Softmax, and Layer Normalization, I-BERT performs an end-to-end integer-only BERT
inference without any floating point calculation. We evaluate our approach on GLUE downstream tasks using
RoBERTa-Base/Large. We show that for both cases, I-BERT achieves similar (and slightly higher) accuracy as compared to
the full-precision baseline. Furthermore, our preliminary implementation of I-BERT shows a speedup of 2.4 - 4.0x for
INT8 inference on a T4 GPU system as compared to FP32 inference. The framework has been developed in PyTorch and has
been open-sourced.*


The original code can be found `here <https://github.com/kssteven418/I-BERT>`__.
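
A minimal usage sketch (the checkpoint name comes from the pretrained archive map added in this PR; the task and
inputs are illustrative)::

    from transformers import AutoTokenizer, IBertForMaskedLM

    tokenizer = AutoTokenizer.from_pretrained("kssteven/ibert-roberta-base")
    model = IBertForMaskedLM.from_pretrained("kssteven/ibert-roberta-base")

    inputs = tokenizer("The capital of France is <mask>.", return_tensors="pt")
    outputs = model(**inputs)  # masked-LM logits, one vector per token
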
IBertConfig
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertConfig
    :members:


IBertModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertModel
    :members: forward


IBertForMaskedLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertForMaskedLM
    :members: forward


IBertForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertForSequenceClassification
    :members: forward


IBertForMultipleChoice
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertForMultipleChoice
    :members: forward


IBertForTokenClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertForTokenClassification
    :members: forward


IBertForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.IBertForQuestionAnswering
    :members: forward
@@ -182,6 +182,7 @@ _import_structure = {
    "models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"],
    "models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"],
    "models.herbert": ["HerbertTokenizer"],
    "models.ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
    "models.layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMTokenizer"],
    "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"],
    "models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"],

@@ -613,6 +614,20 @@ if is_torch_available():
            "load_tf_weights_in_gpt2",
        ]
    )
    _import_structure["models.ibert"].extend(
        [
            "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
            "IBertForMaskedLM",
            "IBertForMultipleChoice",
            "IBertForQuestionAnswering",
            "IBertForSequenceClassification",
            "IBertForTokenClassification",
            "IBertLayer",
            "IBertModel",
            "IBertPreTrainedModel",
            "load_tf_weights_in_ibert",
        ]
    )
    _import_structure["models.layoutlm"].extend(
        [
            "LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST",

@@ -1328,6 +1343,7 @@ if TYPE_CHECKING:
    from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer
    from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer
    from .models.herbert import HerbertTokenizer
    from .models.ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
    from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer
    from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer
    from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer

@@ -1710,6 +1726,15 @@ if TYPE_CHECKING:
            GPT2PreTrainedModel,
            load_tf_weights_in_gpt2,
        )
        from .models.ibert import (
            IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            IBertForMaskedLM,
            IBertForMultipleChoice,
            IBertForQuestionAnswering,
            IBertForSequenceClassification,
            IBertForTokenClassification,
            IBertModel,
        )
        from .models.layoutlm import (
            LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST,
            LayoutLMForMaskedLM,
@@ -40,6 +40,7 @@ from ..flaubert.configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE
from ..fsmt.configuration_fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig
from ..funnel.configuration_funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig
from ..gpt2.configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from ..ibert.configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
from ..layoutlm.configuration_layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig
from ..led.configuration_led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig
from ..longformer.configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig

@@ -110,6 +111,7 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = dict(
        PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
        TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP,
        IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
    ]
    for key, value, in pretrained_map.items()
)

@@ -123,6 +125,7 @@ CONFIG_MAPPING = OrderedDict(
        ("led", LEDConfig),
        ("blenderbot-small", BlenderbotSmallConfig),
        ("retribert", RetriBertConfig),
        ("ibert", IBertConfig),
        ("mt5", MT5Config),
        ("t5", T5Config),
        ("mobilebert", MobileBertConfig),

@@ -173,6 +176,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
        ("led", "LED"),
        ("blenderbot-small", "BlenderbotSmall"),
        ("retribert", "RetriBERT"),
        ("ibert", "I-BERT"),
        ("t5", "T5"),
        ("mobilebert", "MobileBERT"),
        ("distilbert", "DistilBERT"),
@@ -129,6 +129,14 @@ from ..funnel.modeling_funnel import (
    FunnelModel,
)
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
from ..ibert.modeling_ibert import (
    IBertForMaskedLM,
    IBertForMultipleChoice,
    IBertForQuestionAnswering,
    IBertForSequenceClassification,
    IBertForTokenClassification,
    IBertModel,
)
from ..layoutlm.modeling_layoutlm import (
    LayoutLMForMaskedLM,
    LayoutLMForSequenceClassification,

@@ -270,6 +278,7 @@ from .configuration_auto import (
    FSMTConfig,
    FunnelConfig,
    GPT2Config,
    IBertConfig,
    LayoutLMConfig,
    LEDConfig,
    LongformerConfig,

@@ -347,6 +356,7 @@ MODEL_MAPPING = OrderedDict(
        (MPNetConfig, MPNetModel),
        (TapasConfig, TapasModel),
        (MarianConfig, MarianModel),
        (IBertConfig, IBertModel),
    ]
)

@@ -379,6 +389,7 @@ MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
        (FunnelConfig, FunnelForPreTraining),
        (MPNetConfig, MPNetForMaskedLM),
        (TapasConfig, TapasForMaskedLM),
        (IBertConfig, IBertForMaskedLM),
    ]
)

@@ -418,6 +429,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
        (TapasConfig, TapasForMaskedLM),
        (DebertaConfig, DebertaForMaskedLM),
        (DebertaV2Config, DebertaV2ForMaskedLM),
        (IBertConfig, IBertForMaskedLM),
    ]
)

@@ -476,6 +488,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
        (TapasConfig, TapasForMaskedLM),
        (DebertaConfig, DebertaForMaskedLM),
        (DebertaV2Config, DebertaV2ForMaskedLM),
        (IBertConfig, IBertForMaskedLM),
    ]
)

@@ -529,6 +542,7 @@ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING = OrderedDict(
        (TransfoXLConfig, TransfoXLForSequenceClassification),
        (MPNetConfig, MPNetForSequenceClassification),
        (TapasConfig, TapasForSequenceClassification),
        (IBertConfig, IBertForSequenceClassification),
    ]
)

@@ -558,6 +572,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict(
        (MPNetConfig, MPNetForQuestionAnswering),
        (DebertaConfig, DebertaForQuestionAnswering),
        (DebertaV2Config, DebertaV2ForQuestionAnswering),
        (IBertConfig, IBertForQuestionAnswering),
    ]
)

@@ -591,6 +606,7 @@ MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING = OrderedDict(
        (MPNetConfig, MPNetForTokenClassification),
        (DebertaConfig, DebertaForTokenClassification),
        (DebertaV2Config, DebertaV2ForTokenClassification),
        (IBertConfig, IBertForTokenClassification),
    ]
)

@@ -613,6 +629,7 @@ MODEL_FOR_MULTIPLE_CHOICE_MAPPING = OrderedDict(
        (FlaubertConfig, FlaubertForMultipleChoice),
        (FunnelConfig, FunnelForMultipleChoice),
        (MPNetConfig, MPNetForMultipleChoice),
        (IBertConfig, IBertForMultipleChoice),
    ]
)
@@ -75,6 +75,7 @@ from .configuration_auto import (
    FSMTConfig,
    FunnelConfig,
    GPT2Config,
    IBertConfig,
    LayoutLMConfig,
    LEDConfig,
    LongformerConfig,

@@ -244,6 +245,7 @@ TOKENIZER_MAPPING = OrderedDict(
        (TapasConfig, (TapasTokenizer, None)),
        (LEDConfig, (LEDTokenizer, LEDTokenizerFast)),
        (ConvBertConfig, (ConvBertTokenizer, ConvBertTokenizerFast)),
        (IBertConfig, (RobertaTokenizer, RobertaTokenizerFast)),
        (Wav2Vec2Config, (Wav2Vec2CTCTokenizer, None)),
    ]
)
@@ -0,0 +1,69 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ...file_utils import _BaseLazyModule, is_tokenizers_available, is_torch_available


_import_structure = {
    "configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
}

if is_torch_available():
    _import_structure["modeling_ibert"] = [
        "IBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
        "IBertForMaskedLM",
        "IBertForMultipleChoice",
        "IBertForQuestionAnswering",
        "IBertForSequenceClassification",
        "IBertForTokenClassification",
        "IBertModel",
    ]

if TYPE_CHECKING:
    from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig

    if is_torch_available():
        from .modeling_ibert import (
            IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
            IBertForMaskedLM,
            IBertForMultipleChoice,
            IBertForQuestionAnswering,
            IBertForSequenceClassification,
            IBertForTokenClassification,
            IBertModel,
        )

else:
    import importlib
    import os
    import sys

    class _LazyModule(_BaseLazyModule):
        """
        Module class that surfaces all objects but only performs associated imports when the objects are requested.
        """

        __file__ = globals()["__file__"]
        __path__ = [os.path.dirname(__file__)]

        def _get_module(self, module_name: str):
            return importlib.import_module("." + module_name, self.__name__)

    sys.modules[__name__] = _LazyModule(__name__, _import_structure)
@@ -0,0 +1,125 @@
# coding=utf-8
# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" I-BERT configuration """

from ...configuration_utils import PretrainedConfig
from ...utils import logging


logger = logging.get_logger(__name__)

IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "kssteven/ibert-roberta-base": "https://huggingface.co/kssteven/ibert-roberta-base/resolve/main/config.json",
    "kssteven/ibert-roberta-large": "https://huggingface.co/kssteven/ibert-roberta-large/resolve/main/config.json",
    "kssteven/ibert-roberta-large-mnli": "https://huggingface.co/kssteven/ibert-roberta-large-mnli/resolve/main/config.json",
}


class IBertConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a :class:`~transformers.IBertModel`. It is used to
    instantiate an I-BERT model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
    outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.

    Args:
        vocab_size (:obj:`int`, `optional`, defaults to 30522):
            Vocabulary size of the I-BERT model. Defines the number of different tokens that can be represented by the
            :obj:`inputs_ids` passed when calling :class:`~transformers.IBertModel`
        hidden_size (:obj:`int`, `optional`, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (:obj:`int`, `optional`, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, `optional`, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (:obj:`int`, `optional`, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
        hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (:obj:`int`, `optional`, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (:obj:`int`, `optional`, defaults to 2):
            The vocabulary size of the :obj:`token_type_ids` passed when calling :class:`~transformers.IBertModel`
        initializer_range (:obj:`float`, `optional`, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
            Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
            :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
            :obj:`"relative_key"`, please refer to `Self-Attention with Relative Position Representations (Shaw et al.)
            <https://arxiv.org/abs/1803.02155>`__. For more information on :obj:`"relative_key_query"`, please refer to
            `Method 4` in `Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)
            <https://arxiv.org/abs/2009.13658>`__.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to quantize the model or not.
        force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            Force dequantize specific nonlinear layers. Dequantized layers are then executed with full precision.
            :obj:`"none"`, :obj:`"gelu"`, :obj:`"softmax"`, :obj:`"layernorm"` and :obj:`"nonlinear"` are supported. By
            default, it is set as :obj:`"none"`, which does not dequantize any layers. Please specify :obj:`"gelu"`,
            :obj:`"softmax"`, or :obj:`"layernorm"` to dequantize GELU, Softmax, or LayerNorm, respectively.
            :obj:`"nonlinear"` will dequantize all nonlinear layers, i.e., GELU, Softmax, and LayerNorm.
    """

    model_type = "ibert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
        position_embedding_type="absolute",
        quant_mode=False,
        force_dequant="none",
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.quant_mode = quant_mode
        self.force_dequant = force_dequant
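
# Usage sketch for the two quantization-specific options above (illustrative
# values, not part of the original file contents):
#
#     from transformers import IBertConfig, IBertModel
#
#     config = IBertConfig()  # full precision (quant_mode=False), RoBERTa-like
#     config = IBertConfig(quant_mode=True, force_dequant="softmax")
#     model = IBertModel(config)  # integer-only, Softmax kept in full precision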
File diff suppressed because it is too large
@@ -0,0 +1,829 @@
# coding=utf-8
# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import decimal

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

from ...utils import logging


logger = logging.get_logger(__name__)


class QuantEmbedding(nn.Module):
    """
    Quantized version of :obj:`torch.nn.Embedding`. Adds quantization-specific arguments on top of
    :obj:`torch.nn.Embedding`.

    Args:
        weight_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
            Bitwidth for the quantized weight.
        momentum (:obj:`float`, `optional`, defaults to :obj:`0.95`):
            Momentum for updating the activation quantization range.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        weight_bit=8,
        momentum=0.95,
        quant_mode=False,
    ):
        super().__init__()
        self.num_ = num_embeddings
        self.dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        self.weight = nn.Parameter(torch.zeros([num_embeddings, embedding_dim]))
        self.register_buffer("weight_scaling_factor", torch.zeros(1))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))

        self.weight_bit = weight_bit
        self.momentum = momentum
        self.quant_mode = quant_mode
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def forward(self, x, positions=None, incremental_state=None):
        if not self.quant_mode:
            return (
                F.embedding(
                    x,
                    self.weight,
                    self.padding_idx,
                    self.max_norm,
                    self.norm_type,
                    self.scale_grad_by_freq,
                    self.sparse,
                ),
                None,
            )

        w = self.weight
        w_transform = w.data.detach()
        w_min = w_transform.min().expand(1)
        w_max = w_transform.max().expand(1)

        self.weight_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, False)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.weight_scaling_factor
        )

        emb_int = F.embedding(
            x,
            self.weight_integer,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        return emb_int * self.weight_scaling_factor, self.weight_scaling_factor
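
# A toy round trip through QuantEmbedding above (sketch; the sizes and the
# final check are illustrative, not part of the module):
#
#     emb = QuantEmbedding(10, 4, weight_bit=8, quant_mode=True)
#     out, scale = emb(torch.tensor([[1, 2, 3]]))
#     # `out` is an integer grid times `scale`, so out / scale is whole-valued
#     assert torch.allclose(out / scale, torch.round(out / scale))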
class QuantAct(nn.Module):
    """
    Quantizes the given activation.

    Args:
        activation_bit (:obj:`int`):
            Bitwidth for the quantized activation.
        act_range_momentum (:obj:`float`, `optional`, defaults to :obj:`0.95`):
            Momentum for updating the activation quantization range.
        per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use channel-wise quantization.
        channel_len (:obj:`int`, `optional`, defaults to :obj:`None`):
            Specify the channel length when `per_channel` is set to True.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
    """

    def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, channel_len=None, quant_mode=False):
        super().__init__()

        self.activation_bit = activation_bit
        self.act_range_momentum = act_range_momentum
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.percentile = False
        self.act_function = SymmetricQuantFunction.apply

        if not self.per_channel:
            self.register_buffer("x_min", torch.zeros(1))
            self.register_buffer("x_max", torch.zeros(1))
            self.register_buffer("act_scaling_factor", torch.zeros(1))
            self.x_min -= 1e-5
            self.x_max += 1e-5
        else:
            raise NotImplementedError("per-channel mode is not currently supported for activation.")

    def __repr__(self):
        return (
            "{0}(activation_bit={1}, "
            "quant_mode: {2}, Act_min: {3:.2f}, "
            "Act_max: {4:.2f})".format(
                self.__class__.__name__, self.activation_bit, self.quant_mode, self.x_min.item(), self.x_max.item()
            )
        )

    def forward(
        self,
        x,
        pre_act_scaling_factor=None,
        identity=None,
        identity_scaling_factor=None,
        specified_min=None,
        specified_max=None,
    ):

        x_act = x if identity is None else identity + x
        # collect running stats if training
        if self.training:
            assert not self.percentile, "percentile mode is not currently supported for activation."
            assert not self.per_channel, "per-channel mode is not currently supported for activation."
            x_min = x_act.data.min()
            x_max = x_act.data.max()

            assert (
                x_max.isnan().sum() == 0 and x_min.isnan().sum() == 0
            ), "NaN detected when computing min/max of the activation"

            # Initialization
            if self.x_min.min() > -1.1e-5 and self.x_max.max() < 1.1e-5:
                self.x_min = self.x_min + x_min
                self.x_max = self.x_max + x_max

            # exponential moving average (EMA)
            # use momentum to prevent the quantized values from changing greatly at every iteration
            elif self.act_range_momentum == -1:
                self.x_min = torch.min(self.x_min, x_min)
                self.x_max = torch.max(self.x_max, x_max)
            else:
                self.x_min = self.x_min * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
                self.x_max = self.x_max * self.act_range_momentum + x_max * (1 - self.act_range_momentum)

        if not self.quant_mode:
            return x_act, None

        x_min = self.x_min if specified_min is None else specified_min
        x_max = self.x_max if specified_max is None else specified_max

        self.act_scaling_factor = symmetric_linear_quantization_params(
            self.activation_bit, x_min, x_max, per_channel=self.per_channel
        )

        if pre_act_scaling_factor is None:
            # this is for the input quantization
            quant_act_int = self.act_function(x, self.activation_bit, self.percentile, self.act_scaling_factor)
        else:
            quant_act_int = FixedPointMul.apply(
                x,
                pre_act_scaling_factor,
                self.activation_bit,
                self.act_scaling_factor,
                identity,
                identity_scaling_factor,
            )

        correct_output_scale = self.act_scaling_factor.view(-1)

        return quant_act_int * correct_output_scale, self.act_scaling_factor
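
# Worked example of the EMA range update in QuantAct.forward above
# (illustrative numbers): with act_range_momentum = 0.95, a running x_max of
# 4.0 and a batch max of 6.0 give
#     4.0 * 0.95 + 6.0 * 0.05 = 4.1
# so the calibrated range drifts slowly toward the observed activation
# statistics instead of jumping on every batch.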
class QuantLinear(nn.Module):
    """
    Quantized version of :obj:`torch.nn.Linear`. Adds quantization-specific arguments on top of :obj:`torch.nn.Linear`.

    Args:
        weight_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
            Bitwidth for the quantized weight.
        bias_bit (:obj:`int`, `optional`, defaults to :obj:`32`):
            Bitwidth for the quantized bias.
        per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use channel-wise quantization.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self, in_features, out_features, bias=True, weight_bit=8, bias_bit=32, per_channel=False, quant_mode=False
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.zeros([out_features, in_features]))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))
        self.register_buffer("fc_scaling_factor", torch.zeros(self.out_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
            self.register_buffer("bias_integer", torch.zeros_like(self.bias))
        else:
            # keep the attributes defined so that forward() can pass bias=None
            self.bias = None
            self.bias_integer = None

        self.weight_bit = weight_bit
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.bias_bit = bias_bit
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def __repr__(self):
        s = super().__repr__()
        s = "(" + s + " weight_bit={}, quant_mode={})".format(self.weight_bit, self.quant_mode)
        return s

    def forward(self, x, prev_act_scaling_factor=None):
        if not self.quant_mode:
            return F.linear(x, weight=self.weight, bias=self.bias), None

        # assert that prev_act_scaling_factor is a scalar tensor
        assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), (
            "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
            "Please add a QuantAct layer with `per_channel = False` before this QuantLinear layer"
        )

        w = self.weight
        w_transform = w.data.detach()
        if self.per_channel:
            w_min, _ = torch.min(w_transform, dim=1, out=None)
            w_max, _ = torch.max(w_transform, dim=1, out=None)
        else:
            w_min = w_transform.min().expand(1)
            w_max = w_transform.max().expand(1)

        self.fc_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, self.per_channel)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
        )

        bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor

        if self.bias is not None:
            self.bias_integer = self.weight_function(self.bias, self.bias_bit, False, bias_scaling_factor)

        prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
        x_int = x / prev_act_scaling_factor

        return (
            F.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
            bias_scaling_factor,
        )
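
# Scale bookkeeping in QuantLinear.forward above (sketch): with
#     x = x_int * s_x,  W = W_int * s_w,  b ~= b_int * (s_w * s_x)
# the integer matmul gives
#     y = (W_int @ x_int + b_int) * (s_w * s_x)
# which is why bias_scaling_factor (= s_w * s_x) is both the factor applied
# to the output and the scaling factor returned to the next layer.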
class IntGELU(nn.Module):
    """
    Quantized version of :obj:`torch.nn.GELU`. Adds quantization-specific arguments on top of :obj:`torch.nn.GELU`.

    Args:
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the layer is quantized.
        force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            Force dequantize the layer if either "gelu" or "nonlinear" is given.
    """

    def __init__(self, quant_mode=True, force_dequant="none"):
        super().__init__()
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "gelu"]:
            logger.info("Force dequantize gelu")
            self.quant_mode = False

        if not self.quant_mode:
            self.activation_fn = nn.GELU()

        self.k = 1.4142
        self.const = 14  # dummy integer constant
        self.coeff = [-0.2888, -1.769, 1]  # a(x+b)**2 + c
        self.coeff[2] /= self.coeff[0]

    def int_erf(self, x_int, scaling_factor):
        b_int = torch.floor(self.coeff[1] / scaling_factor)
        c_int = torch.floor(self.coeff[2] / scaling_factor ** 2)
        sign = torch.sign(x_int)

        abs_int = torch.min(torch.abs(x_int), -b_int)
        y_int = sign * ((abs_int + b_int) ** 2 + c_int)
        scaling_factor = scaling_factor ** 2 * self.coeff[0]

        # avoid overflow
        y_int = floor_ste.apply(y_int / 2 ** self.const)
        scaling_factor = scaling_factor * 2 ** self.const

        return y_int, scaling_factor

    def forward(self, x, scaling_factor=None):
        if not self.quant_mode:
            return self.activation_fn(x), None

        x_int = x / scaling_factor
        sigmoid_int, sigmoid_scaling_factor = self.int_erf(x_int, scaling_factor / self.k)

        shift_int = 1.0 // sigmoid_scaling_factor

        x_int = x_int * (sigmoid_int + shift_int)
        scaling_factor = scaling_factor * sigmoid_scaling_factor / 2

        return x_int * scaling_factor, scaling_factor
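
# The integer erf above implements the i-GELU approximation (sketch, using
# the coefficients a = -0.2888, b = -1.769, c = 1 from self.coeff):
#     erf(x / sqrt(2)) ~= sign(x) * (a * (min(|x|, -b) + b) ** 2 + c)
#     GELU(x) = x * 0.5 * (1 + erf(x / sqrt(2)))
# with k = 1.4142 ~= sqrt(2); every term is evaluated on an integer grid,
# with the floating-point scaling factors tracked separately.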
class IntSoftmax(nn.Module):
    """
    Quantized version of :obj:`torch.nn.Softmax`. Adds quantization-specific arguments on top of
    :obj:`torch.nn.Softmax`.

    Args:
        output_bit (:obj:`int`):
            Bitwidth for the layer output activation.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
        force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            Force dequantize the layer if either "softmax" or "nonlinear" is given.
    """

    def __init__(self, output_bit, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.output_bit = output_bit
        self.max_bit = 32
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "softmax"]:
            logger.info("Force dequantize softmax")
            self.quant_mode = False

        self.act = QuantAct(16, quant_mode=self.quant_mode)
        self.x0 = -0.6931  # -ln2
        self.const = 30  # dummy integer constant
        self.coef = [0.35815147, 0.96963238, 1.0]  # ax**2 + bx + c
        self.coef[1] /= self.coef[0]
        self.coef[2] /= self.coef[0]

    def int_polynomial(self, x_int, scaling_factor):
        with torch.no_grad():
            b_int = torch.floor(self.coef[1] / scaling_factor)
            c_int = torch.floor(self.coef[2] / scaling_factor ** 2)
        z = (x_int + b_int) * x_int + c_int
        scaling_factor = self.coef[0] * scaling_factor ** 2
        return z, scaling_factor

    def int_exp(self, x_int, scaling_factor):
        with torch.no_grad():
            x0_int = torch.floor(self.x0 / scaling_factor)
        x_int = torch.max(x_int, self.const * x0_int)

        q = floor_ste.apply(x_int / x0_int)
        r = x_int - x0_int * q
        exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor)
        exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.const - q)), min=0)
        scaling_factor = exp_scaling_factor / 2 ** self.const
        return exp_int, scaling_factor

    def forward(self, x, scaling_factor):
        if not self.quant_mode:
            return nn.Softmax(dim=-1)(x), None

        x_int = x / scaling_factor

        x_int_max, _ = x_int.max(dim=-1, keepdim=True)
        x_int = x_int - x_int_max
        exp_int, exp_scaling_factor = self.int_exp(x_int, scaling_factor)

        # Avoid overflow
        exp, exp_scaling_factor = self.act(exp_int, exp_scaling_factor)
        exp_int = exp / exp_scaling_factor

        exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
        factor = floor_ste.apply(2 ** self.max_bit / exp_int_sum)
        exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit))
        scaling_factor = 1 / 2 ** self.output_bit
        return exp_int * scaling_factor, scaling_factor
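
# int_exp above uses the decomposition exp(x) = 2 ** (-q) * exp(r), where
# x = q * (-ln 2) + r and r lies in (-ln 2, 0], so only exp(r) needs the
# second-order polynomial in int_polynomial. Worked example (sketch):
#     x = -2.0  ->  q = 2,  r = -2.0 - 2 * (-0.6931) = -0.6138
#     exp(-2.0) ~= poly(-0.6138) / 2 ** 2
# The final normalization then replaces division by sum(exp) with an integer
# multiply and shift against 2 ** max_bit.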
class IntLayerNorm(nn.Module):
    """
    Quantized version of :obj:`torch.nn.LayerNorm`. Adds quantization-specific arguments on top of
    :obj:`torch.nn.LayerNorm`.

    Args:
        output_bit (:obj:`int`, `optional`, defaults to :obj:`8`):
            Bitwidth for the layer output activation.
        quant_mode (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the layer is quantized.
        force_dequant (:obj:`str`, `optional`, defaults to :obj:`"none"`):
            Force dequantize the layer if either "layernorm" or "nonlinear" is given.
    """

    def __init__(self, normalized_shape, eps, output_bit=8, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        self.weight = nn.Parameter(torch.zeros(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

        self.quant_mode = quant_mode
        if force_dequant in ["nonlinear", "layernorm"]:
            logger.info("Force dequantize layernorm")
            self.quant_mode = False

        self.register_buffer("shift", torch.zeros(1))
        self.output_bit = output_bit
        self.max_bit = 32
        self.dim_sqrt = None
        self.activation = QuantAct(self.output_bit, quant_mode=self.quant_mode)

    def set_shift(self, y_int):
        with torch.no_grad():
            y_sq_int = y_int ** 2
            var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
            shift = (torch.log2(torch.sqrt(var_int / 2 ** self.max_bit)).ceil()).max()
            shift_old = self.shift
            self.shift = torch.max(self.shift, shift)
            logger.info("Dynamic shift adjustment: {} -> {}".format(int(shift_old), int(self.shift)))

    def overflow_fallback(self, y_int):
        """
        This fallback function is called when overflow is detected during training time, and adjusts the `self.shift`
        to avoid overflow in the subsequent runs.
        """
        self.set_shift(y_int)  # adjusts `self.shift`
        y_int_shifted = floor_ste.apply(y_int / 2 ** self.shift)
        y_sq_int = y_int_shifted ** 2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
        return var_int

    def forward(self, x, scaling_factor=None):
        if not self.quant_mode:
            mean = x.mean(axis=2, keepdim=True)
            y = x - mean
            var = torch.mean(y ** 2, axis=2, keepdim=True)
            x = y / torch.sqrt(self.eps + var)
            x = x * self.weight + self.bias
            return x, None

        # compute sqrt of the feature dimension if it is the first run
        if self.dim_sqrt is None:
            n = torch.tensor(x.shape[2], dtype=torch.float)
            self.dim_sqrt = torch.sqrt(n).to(x.device)

        # Normalization: computes mean and variance (std)
        x_int = x / scaling_factor
        mean_int = round_ste.apply(x_int.mean(axis=2, keepdim=True))
        y_int = x_int - mean_int
        y_int_shifted = floor_ste.apply(y_int / 2 ** self.shift)
        y_sq_int = y_int_shifted ** 2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)

        # overflow handling in training time
        if self.training:
            # if overflow is detected
            if var_int.max() >= 2 ** self.max_bit:
                var_int = self.overflow_fallback(y_int)
                assert var_int.max() < 2 ** self.max_bit + 0.1, (
                    "Error detected in overflow handling: "
                    "`var_int` exceeds `self.max_bit` (the maximum possible bit width)"
                )

        # To be replaced with integer-sqrt kernel that produces the same output
        std_int = floor_ste.apply(torch.sqrt(var_int)) * 2 ** self.shift
        factor = floor_ste.apply(2 ** 31 / std_int)
        y_int = floor_ste.apply(y_int * factor / 2)
        scaling_factor = self.dim_sqrt / 2 ** 30

        # scaling and shifting
        bias = self.bias.data.detach() / (self.weight.data.detach())
        bias_int = floor_ste.apply(bias / scaling_factor)

        y_int = y_int + bias_int
        scaling_factor = scaling_factor * self.weight
        x = y_int * scaling_factor

        return x, scaling_factor
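
# The dynamic `shift` in IntLayerNorm above exists because sum(y_int ** 2)
# over the hidden dimension must stay below 2 ** 32 (self.max_bit). Sketch:
#     y_int_shifted = floor(y_int / 2 ** shift)    # shrink before squaring
#     var_int = sum(y_int_shifted ** 2)            # now fits in 32 bits
#     std_int = floor(sqrt(var_int)) * 2 ** shift  # undo the shift afterwards
# When an overflow is still detected during training, overflow_fallback raises
# `shift` and recomputes var_int.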
def get_percentile_min_max(input, lower_percentile, upper_percentile, output_tensor=False):
    """
    Calculate the percentile max and min values in a given tensor.

    Args:
        input (:obj:`torch.Tensor`):
            The target tensor to calculate percentile max and min.
        lower_percentile (:obj:`float`):
            If set to 0.1, the value at the boundary of the smallest 0.1% of values in the tensor is returned as the
            percentile min.
        upper_percentile (:obj:`float`):
            If set to 99.9, the value at the boundary of the largest 0.1% of values in the tensor is returned as the
            percentile max.
        output_tensor (:obj:`bool`, `optional`, defaults to :obj:`False`):
            If True, this function returns tensors, otherwise it returns values.

    Returns:
        :obj:`Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of `input`
    """
    input_length = input.shape[0]

    lower_index = round(input_length * (1 - lower_percentile * 0.01))
    upper_index = round(input_length * upper_percentile * 0.01)

    upper_bound = torch.kthvalue(input, k=upper_index).values

    if lower_percentile == 0:
        lower_bound = upper_bound * 0
        # lower_index += 1
    else:
        lower_bound = -torch.kthvalue(-input, k=lower_index).values

    if not output_tensor:
        lower_bound = lower_bound.item()
        upper_bound = upper_bound.item()
    return lower_bound, upper_bound


def linear_quantize(input, scale, zero_point, inplace=False):
    """
    Quantize single-precision input tensor to integers with the given scaling factor and zeropoint.

    Args:
        input (:obj:`torch.Tensor`):
            Single-precision input tensor to be quantized.
        scale (:obj:`torch.Tensor`):
            Scaling factor for quantization.
        zero_point (:obj:`torch.Tensor`):
            Shift for quantization.
        inplace (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to compute inplace or not.

    Returns:
        :obj:`torch.Tensor`: Linearly quantized value of `input` according to `scale` and `zero_point`.
    """
    # reshape scale and zeropoint for convolutional weights and activation
    if len(input.shape) == 4:
        scale = scale.view(-1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1)
    # reshape scale and zeropoint for linear weights
    elif len(input.shape) == 2:
        scale = scale.view(-1, 1)
        zero_point = zero_point.view(-1, 1)
    else:
        scale = scale.view(-1)
        zero_point = zero_point.view(-1)
    # quantized = float / scale + zero_point
    if inplace:
        input.mul_(1.0 / scale).add_(zero_point).round_()
        return input
    return torch.round(1.0 / scale * input + zero_point)
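
# Worked example for linear_quantize above (illustrative values): with
# scale = 0.05 and zero_point = 0,
#     round(0.12 / 0.05) = 2    and    round(-0.26 / 0.05) = -5
# so tensor([0.12, -0.26]) maps to the integer grid tensor([2., -5.]).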
def symmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, per_channel=False):
    """
    Compute the scaling factor with the given quantization range for symmetric quantization.

    Args:
        num_bits (:obj:`int`):
            Quantization bitwidth.
        saturation_min (:obj:`torch.Tensor`):
            Lower bound for quantization range.
        saturation_max (:obj:`torch.Tensor`):
            Upper bound for quantization range.
        per_channel (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use channel-wise quantization.

    Returns:
        :obj:`torch.Tensor`: Scaling factor that linearly quantizes the given range between `saturation_min` and
        `saturation_max`.
    """
    # in this part, we do not need any gradient computation,
    # in order to enforce this, we put torch.no_grad()
    with torch.no_grad():
        n = 2 ** (num_bits - 1) - 1

        if per_channel:
            scale, _ = torch.max(torch.stack([saturation_min.abs(), saturation_max.abs()], dim=1), dim=1)
            scale = torch.clamp(scale, min=1e-8) / n

        else:
            scale = max(saturation_min.abs(), saturation_max.abs())
            scale = torch.clamp(scale, min=1e-8) / n

    return scale
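
# Worked example (illustrative): num_bits = 8 and a weight range of
# [-0.42, 0.39] give
#     n = 2 ** 7 - 1 = 127,  scale = max(0.42, 0.39) / 127 ~= 0.0033
# so floats are mapped onto the symmetric integer grid [-127, 127].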
class SymmetricQuantFunction(Function):
    """
    Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth.
    """

    @staticmethod
    def forward(ctx, x, k, percentile_mode, scale):
        """
        Args:
            x (:obj:`torch.Tensor`):
                Floating point tensor to be quantized.
            k (:obj:`int`):
                Quantization bitwidth.
            percentile_mode (:obj:`bool`):
                Whether or not to use percentile calibration.
            scale (:obj:`torch.Tensor`):
                Pre-calculated scaling factor for `x`. Note that the current implementation of SymmetricQuantFunction
                requires pre-calculated scaling factor.

        Returns:
            :obj:`torch.Tensor`: Symmetric-quantized value of `x`.
        """
        zero_point = torch.tensor(0.0).to(scale.device)

        n = 2 ** (k - 1) - 1
        new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
        new_quant_x = torch.clamp(new_quant_x, -n, n - 1)

        ctx.scale = scale
        return new_quant_x

    @staticmethod
    def backward(ctx, grad_output):

        scale = ctx.scale
        if len(grad_output.shape) == 4:
            scale = scale.view(-1, 1, 1, 1)
        # reshape scale and zeropoint for linear weights
        elif len(grad_output.shape) == 2:
            scale = scale.view(-1, 1)
        else:
            scale = scale.view(-1)

        # one gradient per forward input (x, k, percentile_mode, scale)
        return grad_output.clone() / scale, None, None, None


class floor_ste(Function):
    """
    Straight-through Estimator (STE) for torch.floor()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.floor(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()


class round_ste(Function):
    """
    Straight-through Estimator (STE) for torch.round()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()
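
# Why the two STE Functions above exist (sketch): floor() and round() have
# zero gradient almost everywhere, which would stop backpropagation through
# any quantization step. The straight-through estimator simply passes the
# incoming gradient through unchanged, treating the rounding as the identity
# in the backward pass, so quantized models can train with ordinary backprop.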
def batch_frexp(inputs, max_bit=31):
    """
    Decompose the scaling factor into mantissa and twos exponent.

    Args:
        inputs (:obj:`torch.Tensor`):
            Target scaling factor to decompose.

    Returns:
        :obj:`Tuple(torch.Tensor, torch.Tensor)`: mantissa and exponent
    """

    shape_of_input = inputs.size()

    # flatten the input into a 1-d tensor
    inputs = inputs.view(-1)

    output_m, output_e = np.frexp(inputs.cpu().numpy())
    tmp_m = []
    for m in output_m:
        int_m_shifted = int(
            decimal.Decimal(m * (2 ** max_bit)).quantize(decimal.Decimal("1"), rounding=decimal.ROUND_HALF_UP)
        )
        tmp_m.append(int_m_shifted)
    output_m = np.array(tmp_m)

    output_e = float(max_bit) - output_e

    return (
        torch.from_numpy(output_m).to(inputs.device).view(shape_of_input),
        torch.from_numpy(output_e).to(inputs.device).view(shape_of_input),
    )
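
# Worked example for batch_frexp above (illustrative): np.frexp(0.75)
# returns (0.75, 0) since 0.75 = 0.75 * 2 ** 0. With max_bit = 31 this
# becomes m = round(0.75 * 2 ** 31) = 1610612736 and e = 31 - 0 = 31, so
# multiplying by m and shifting right by e reproduces multiplication by
# 0.75 using only integer arithmetic.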
class FixedPointMul(Function):
    """
    Function to perform fixed-point arithmetic that can match integer arithmetic on hardware.

    Args:
        pre_act (:obj:`torch.Tensor`):
            Input tensor.
        pre_act_scaling_factor (:obj:`torch.Tensor`):
            Scaling factor of the input tensor `pre_act`.
        bit_num (:obj:`int`):
            Quantization bitwidth.
        z_scaling_factor (:obj:`torch.Tensor`):
            Scaling factor of the output tensor.
        identity (:obj:`torch.Tensor`, `optional`, defaults to :obj:`None`):
            Identity tensor, if exists.
        identity_scaling_factor (:obj:`torch.Tensor`, `optional`, defaults to :obj:`None`):
            Scaling factor of the identity tensor `identity`, if exists.

    Returns:
        :obj:`torch.Tensor`: Output tensor (`pre_act` if `identity` is not given, otherwise the addition of `pre_act`
        and `identity`), whose scale is rescaled to `z_scaling_factor`.
    """

    @staticmethod
    def forward(
        ctx,
        pre_act,
        pre_act_scaling_factor,
        bit_num,
        z_scaling_factor,
        identity=None,
        identity_scaling_factor=None,
    ):

        if len(pre_act_scaling_factor.shape) == 3:
            reshape = lambda x: x  # noqa: E731
        else:
            reshape = lambda x: x.view(1, 1, -1)  # noqa: E731
        ctx.identity = identity

        n = 2 ** (bit_num - 1) - 1

        with torch.no_grad():
            pre_act_scaling_factor = reshape(pre_act_scaling_factor)
            if identity is not None:
                identity_scaling_factor = reshape(identity_scaling_factor)

            ctx.z_scaling_factor = z_scaling_factor

            z_int = torch.round(pre_act / pre_act_scaling_factor)
            _A = pre_act_scaling_factor.type(torch.double)
            _B = (z_scaling_factor.type(torch.float)).type(torch.double)
            new_scale = _A / _B
            new_scale = reshape(new_scale)

            m, e = batch_frexp(new_scale)

            output = z_int.type(torch.double) * m.type(torch.double)
            output = torch.round(output / (2.0 ** e))

            if identity is not None:
                # needs addition of identity activation
                wx_int = torch.round(identity / identity_scaling_factor)

                _A = identity_scaling_factor.type(torch.double)
                _B = (z_scaling_factor.type(torch.float)).type(torch.double)
                new_scale = _A / _B
                new_scale = reshape(new_scale)

                m1, e1 = batch_frexp(new_scale)
                output1 = wx_int.type(torch.double) * m1.type(torch.double)
                output1 = torch.round(output1 / (2.0 ** e1))

                output = output1 + output

            return torch.clamp(output.type(torch.float), -n - 1, n)

    @staticmethod
    def backward(ctx, grad_output):
        identity_grad = None
        if ctx.identity is not None:
            identity_grad = grad_output.clone() / ctx.z_scaling_factor
        # one gradient per forward input (pre_act, pre_act_scaling_factor,
        # bit_num, z_scaling_factor, identity, identity_scaling_factor)
        return grad_output.clone() / ctx.z_scaling_factor, None, None, None, identity_grad, None
@ -1349,6 +1349,63 @@ def load_tf_weights_in_gpt2(*args, **kwargs):
|
|||
requires_pytorch(load_tf_weights_in_gpt2)
|
||||
|
||||
|
||||
IBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||
|
||||
|
||||
class IBertForMaskedLM:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class IBertForMultipleChoice:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class IBertForQuestionAnswering:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class IBertForSequenceClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class IBertForTokenClassification:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
|
||||
class IBertModel:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
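
# Editor's note (hedged sketch, not part of the diff): these placeholders keep
# `from transformers import IBertModel` importable even when PyTorch is absent;
# any actual use then raises an informative error instead of a bare ImportError:
#
#   model = IBertModel()  # raises, asking the user to install PyTorch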

LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = None


@@ -0,0 +1,696 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import copy
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask


if is_torch_available():
    import torch
    import torch.nn as nn

    from transformers import (
        IBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        IBertConfig,
        IBertForMaskedLM,
        IBertForMultipleChoice,
        IBertForQuestionAnswering,
        IBertForSequenceClassification,
        IBertForTokenClassification,
        IBertModel,
    )
    from transformers.models.ibert.modeling_ibert import (
        IBertEmbeddings,
        IntGELU,
        IntLayerNorm,
        IntSoftmax,
        QuantAct,
        QuantEmbedding,
        QuantLinear,
        create_position_ids_from_input_ids,
    )


class IBertModelTester:
    def __init__(
        self,
        parent,
    ):
        self.parent = parent
        self.batch_size = 13
        self.seq_length = 7
        self.is_training = True
        self.use_input_mask = True
        self.use_token_type_ids = True
        self.use_labels = True
        self.vocab_size = 99
        self.hidden_size = 32
        self.num_hidden_layers = 5
        self.num_attention_heads = 4
        self.intermediate_size = 37
        self.hidden_act = "gelu"
        self.hidden_dropout_prob = 0.1
        self.attention_probs_dropout_prob = 0.1
        self.max_position_embeddings = 512
        self.type_vocab_size = 16
        self.type_sequence_label_size = 2
        self.initializer_range = 0.02
        self.num_labels = 3
        self.num_choices = 4
        self.scope = None

    def prepare_config_and_inputs(self):
        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

        input_mask = None
        if self.use_input_mask:
            input_mask = random_attention_mask([self.batch_size, self.seq_length])

        token_type_ids = None
        if self.use_token_type_ids:
            token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)

        sequence_labels = None
        token_labels = None
        choice_labels = None
        if self.use_labels:
            sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
            token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
            choice_labels = ids_tensor([self.batch_size], self.num_choices)

        config = IBertConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            num_hidden_layers=self.num_hidden_layers,
            num_attention_heads=self.num_attention_heads,
            intermediate_size=self.intermediate_size,
            hidden_act=self.hidden_act,
            hidden_dropout_prob=self.hidden_dropout_prob,
            attention_probs_dropout_prob=self.attention_probs_dropout_prob,
            max_position_embeddings=self.max_position_embeddings,
            type_vocab_size=self.type_vocab_size,
            initializer_range=self.initializer_range,
            quant_mode=True,
        )

        return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    def create_and_check_model(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = IBertModel(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids)
        result = model(input_ids, token_type_ids=token_type_ids)
        result = model(input_ids)

        self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
        self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))

    def create_and_check_for_masked_lm(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = IBertForMaskedLM(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))

    def create_and_check_for_token_classification(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_labels = self.num_labels
        model = IBertForTokenClassification(config=config)
        model.to(torch_device)
        model.eval()
        result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))

    def create_and_check_for_multiple_choice(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        config.num_choices = self.num_choices
        model = IBertForMultipleChoice(config=config)
        model.to(torch_device)
        model.eval()
        multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
        result = model(
            multiple_choice_inputs_ids,
            attention_mask=multiple_choice_input_mask,
            token_type_ids=multiple_choice_token_type_ids,
            labels=choice_labels,
        )
        self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_choices))

    def create_and_check_for_question_answering(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        model = IBertForQuestionAnswering(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids,
            attention_mask=input_mask,
            token_type_ids=token_type_ids,
            start_positions=sequence_labels,
            end_positions=sequence_labels,
        )
        self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
        self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))

    def prepare_config_and_inputs_for_common(self):
        config_and_inputs = self.prepare_config_and_inputs()
        (
            config,
            input_ids,
            token_type_ids,
            input_mask,
            sequence_labels,
            token_labels,
            choice_labels,
        ) = config_and_inputs
        inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
        return config, inputs_dict
@require_torch
class IBertModelTest(ModelTesterMixin, unittest.TestCase):

    test_pruning = False
    test_torchscript = False
    test_head_masking = False
    test_resize_embeddings = False

    all_model_classes = (
        (
            IBertForMaskedLM,
            IBertModel,
            IBertForSequenceClassification,
            IBertForTokenClassification,
            IBertForMultipleChoice,
            IBertForQuestionAnswering,
        )
        if is_torch_available()
        else ()
    )

    def setUp(self):
        self.model_tester = IBertModelTester(self)
        self.config_tester = ConfigTester(self, config_class=IBertConfig, hidden_size=37)

    def test_config(self):
        self.config_tester.run_common_tests()

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)

    def test_model_various_embeddings(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        # I-BERT only supports absolute embedding
        for type in ["absolute"]:
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    def test_for_masked_lm(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)

    def test_for_token_classification(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_token_classification(*config_and_inputs)

    def test_for_multiple_choice(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)

    def test_for_question_answering(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_for_question_answering(*config_and_inputs)

    @slow
    def test_model_from_pretrained(self):
        for model_name in IBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = IBertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

    def test_create_position_ids_respects_padding_index(self):
        """Ensure that the default position ids only assign sequential positions to non-padding tokens. This is a
        regression test for https://github.com/huggingface/transformers/issues/1761

        The position ids should be masked with the embedding object's padding index. Therefore, the
        first available non-padding position index is IBertEmbeddings.padding_idx + 1
        """
        config = self.model_tester.prepare_config_and_inputs()[0]
        model = IBertEmbeddings(config=config)

        input_ids = torch.as_tensor([[12, 31, 13, model.padding_idx]])
        expected_positions = torch.as_tensor(
            [[0 + model.padding_idx + 1, 1 + model.padding_idx + 1, 2 + model.padding_idx + 1, model.padding_idx]]
        )

        position_ids = create_position_ids_from_input_ids(input_ids, model.padding_idx)
        self.assertEqual(position_ids.shape, expected_positions.shape)
        self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))

    def test_create_position_ids_from_inputs_embeds(self):
        """Ensure that the default position ids only assign sequential positions to non-padding tokens. This is a
        regression test for https://github.com/huggingface/transformers/issues/1761

        The position ids should be masked with the embedding object's padding index. Therefore, the
        first available non-padding position index is IBertEmbeddings.padding_idx + 1
        """
        config = self.model_tester.prepare_config_and_inputs()[0]
        embeddings = IBertEmbeddings(config=config)

        inputs_embeds = torch.Tensor(2, 4, 30)
        expected_single_positions = [
            0 + embeddings.padding_idx + 1,
            1 + embeddings.padding_idx + 1,
            2 + embeddings.padding_idx + 1,
            3 + embeddings.padding_idx + 1,
        ]
        expected_positions = torch.as_tensor([expected_single_positions, expected_single_positions])
        position_ids = embeddings.create_position_ids_from_inputs_embeds(inputs_embeds)
        self.assertEqual(position_ids.shape, expected_positions.shape)
        self.assertTrue(torch.all(torch.eq(position_ids, expected_positions)))
    # Override
    def test_model_common_attributes(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            self.assertIsInstance(model.get_input_embeddings(), QuantEmbedding)
            model.set_input_embeddings(torch.nn.Embedding(10, 10))
            x = model.get_output_embeddings()
            self.assertTrue(x is None or isinstance(x, torch.nn.Linear))

    # Override
    def test_feed_forward_chunking(self):
        pass  # I-BERT does not support chunking

    # Override
    def test_inputs_embeds(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))

            if not self.is_encoder_decoder:
                input_ids = inputs["input_ids"]
                del inputs["input_ids"]
            else:
                encoder_input_ids = inputs["input_ids"]
                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
                del inputs["input_ids"]
                inputs.pop("decoder_input_ids", None)

            wte = model.get_input_embeddings()
            if not self.is_encoder_decoder:
                embed, embed_scaling_factor = wte(input_ids)
                inputs["inputs_embeds"] = embed
            else:
                inputs["inputs_embeds"] = wte(encoder_input_ids)
                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)

            with torch.no_grad():
                model(**inputs)[0]


@require_torch
class IBertModelIntegrationTest(unittest.TestCase):
    def test_quant_embedding(self):
        weight_bit = 8
        embedding = QuantEmbedding(2, 4, quant_mode=True, weight_bit=weight_bit)
        embedding_weight = torch.tensor([[-1.0, -2.0, -3.0, -4.0], [5.0, 6.0, 7.0, 8.0]])
        embedding.weight = torch.nn.Parameter(embedding_weight)

        expected_scaling_factor = embedding_weight.abs().max() / (2 ** (weight_bit - 1) - 1)
        x, x_scaling_factor = embedding(torch.tensor(0))
        y, y_scaling_factor = embedding(torch.tensor(1))

        # scaling factor should follow the symmetric quantization rule
        self.assertTrue(torch.allclose(x_scaling_factor, expected_scaling_factor, atol=1e-4))
        self.assertTrue(torch.allclose(y_scaling_factor, expected_scaling_factor, atol=1e-4))

        # quantization error should not exceed the scaling factor
        self.assertTrue(torch.allclose(x, embedding_weight[0], atol=expected_scaling_factor))
        self.assertTrue(torch.allclose(y, embedding_weight[1], atol=expected_scaling_factor))
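
    # Editor's worked example (hedged sketch, not from the original diff): with
    # weight_bit = 8 the symmetric rule gives scale = max|W| / (2**7 - 1); here
    # max|W| = 8.0, so scale = 8.0 / 127 ≈ 0.0630, and every dequantized weight
    # lands within one scale step of its float value, e.g.
    # round(5.0 / 0.0630) * 0.0630 ≈ 4.98.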
    def test_quant_act(self):
        def _test_range():
            act = QuantAct(activation_bit, act_range_momentum, quant_mode=True)

            # First pass
            x = torch.tensor([[-1.0, -2.0, -3.0, -4.0], [5.0, 6.0, 7.0, 8.0]])
            x_scaling_factor = torch.tensor(1.0)
            y, y_scaling_factor = act(x, x_scaling_factor)
            y_int = y / y_scaling_factor

            # After the first pass, x_min and x_max should be initialized with x.min() and x.max()
            expected_x_min, expected_x_max = x.min(), x.max()
            self.assertTrue(torch.allclose(act.x_min, expected_x_min, atol=1e-4))
            self.assertTrue(torch.allclose(act.x_max, expected_x_max, atol=1e-4))

            # scaling factor should follow the symmetric quantization rule
            expected_range = torch.max(expected_x_min.abs(), expected_x_max.abs())
            expected_scaling_factor = expected_range / (2 ** (activation_bit - 1) - 1)
            self.assertTrue(torch.allclose(y_scaling_factor, expected_scaling_factor, atol=1e-4))

            # quantization error should not exceed the scaling factor
            self.assertTrue(torch.allclose(x, y, atol=expected_scaling_factor))

            # output should be integer
            self.assertTrue(torch.allclose(y_int, y_int.round(), atol=1e-4))

            # Second pass
            x = torch.tensor([[-1.0, -2.0, -3.0, -4.0], [5.0, 6.0, 7.0, 8.0]]) * 2
            x_scaling_factor = torch.tensor(1.0)
            y, y_scaling_factor = act(x, x_scaling_factor)
            y_int = y / y_scaling_factor

            # From the second pass, x_min and x_max should be updated with a moving average
            expected_x_min = expected_x_min * act_range_momentum + x.min() * (1 - act_range_momentum)
            expected_x_max = expected_x_max * act_range_momentum + x.max() * (1 - act_range_momentum)
            self.assertTrue(torch.allclose(act.x_min, expected_x_min, atol=1e-4))
            self.assertTrue(torch.allclose(act.x_max, expected_x_max, atol=1e-4))

            # scaling factor should follow the symmetric quantization rule
            expected_range = torch.max(expected_x_min.abs(), expected_x_max.abs())
            expected_scaling_factor = expected_range / (2 ** (activation_bit - 1) - 1)
            self.assertTrue(torch.allclose(y_scaling_factor, expected_scaling_factor, atol=1e-4))

            # quantization error should not exceed the scaling factor
            x = x.clamp(min=-expected_range, max=expected_range)
            self.assertTrue(torch.allclose(x, y, atol=expected_scaling_factor))

            # output should be integer
            self.assertTrue(torch.allclose(y_int, y_int.round(), atol=1e-4))

            # Third pass, with eval()
            act.eval()
            x = torch.tensor([[-1.0, -2.0, -3.0, -4.0], [5.0, 6.0, 7.0, 8.0]]) * 3

            # In eval mode, min/max and scaling factor must be fixed
            self.assertTrue(torch.allclose(act.x_min, expected_x_min, atol=1e-4))
            self.assertTrue(torch.allclose(act.x_max, expected_x_max, atol=1e-4))
            self.assertTrue(torch.allclose(y_scaling_factor, expected_scaling_factor, atol=1e-4))

        def _test_identity():
            # if identity and identity_scaling_factor are given, the input values should be added
            act = QuantAct(activation_bit, act_range_momentum, quant_mode=True)
            x = torch.tensor([[-1.0, -2.0, -3.0, -4.0], [5.0, 6.0, 7.0, 8.0]])
            y = torch.tensor([[6.0, -7.0, 1.0, -2.0], [3.0, -4.0, -8.0, 5.0]])
            x_scaling_factor = torch.tensor(1.0)
            y_scaling_factor = torch.tensor(0.5)
            z, z_scaling_factor = act(x, x_scaling_factor, y, y_scaling_factor)
            z_int = z / z_scaling_factor
            self.assertTrue(torch.allclose(x + y, z, atol=0.1))
            self.assertTrue(torch.allclose(z_int, z_int.round(), atol=1e-4))

        activation_bit = 8
        act_range_momentum = 0.95
        _test_range()
        _test_identity()
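
    # Editor's worked example (hedged, based only on the assertions above): the
    # running range follows an exponential moving average; with momentum
    # m = 0.95, x_max starts at 8.0 and after seeing 16.0 becomes
    # 0.95 * 8.0 + 0.05 * 16.0 = 8.4, which is what the second pass checks.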
    def test_quant_linear(self):
        def _test(per_channel):
            linear_q = QuantLinear(2, 4, quant_mode=True, per_channel=per_channel, weight_bit=weight_bit)
            linear_dq = QuantLinear(2, 4, quant_mode=False, per_channel=per_channel, weight_bit=weight_bit)
            linear_weight = torch.tensor([[-1.0, 2.0, 3.0, -4.0], [5.0, -6.0, -7.0, 8.0]]).T
            linear_q.weight = torch.nn.Parameter(linear_weight)
            linear_dq.weight = torch.nn.Parameter(linear_weight)

            q, q_scaling_factor = linear_q(x, x_scaling_factor)
            q_int = q / q_scaling_factor
            dq, dq_scaling_factor = linear_dq(x, x_scaling_factor)

            if per_channel:
                q_max = linear_weight.abs().max(dim=1).values
            else:
                q_max = linear_weight.abs().max()
            expected_scaling_factor = q_max / (2 ** (weight_bit - 1) - 1)

            # scaling factor should follow the symmetric quantization rule
            self.assertTrue(torch.allclose(linear_q.fc_scaling_factor, expected_scaling_factor, atol=1e-4))

            # output of the normal linear layer and the quantized linear layer should be similar
            self.assertTrue(torch.allclose(q, dq, atol=0.5))

            # output of the quantized linear layer should be integer
            self.assertTrue(torch.allclose(q_int, q_int.round(), atol=1e-4))

        weight_bit = 8
        x = torch.tensor([[2.0, -5.0], [-3.0, 4.0]])
        x_scaling_factor = torch.tensor([1.0])
        _test(True)
        _test(False)
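
    # Editor's note (hedged): per_channel=True derives one scaling factor per
    # output row of the weight (max over dim=1), while per_channel=False uses a
    # single tensor-wide max; for the transposed weight above, the per-channel
    # scales are [5, 6, 7, 8] / 127 and the per-tensor scale is 8 / 127.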
    def test_int_gelu(self):
        gelu_q = IntGELU(quant_mode=True)
        gelu_dq = torch.nn.GELU()

        x_int = torch.arange(-10000, 10001, 1.0)
        x_scaling_factor = torch.tensor(0.001)
        x = x_int * x_scaling_factor

        q, q_scaling_factor = gelu_q(x, x_scaling_factor)
        q_int = q / q_scaling_factor
        dq = gelu_dq(x)

        # output of the normal GELU and the quantized GELU should be similar
        self.assertTrue(torch.allclose(q, dq, atol=0.5))

        # output of the quantized GELU layer should be integer
        self.assertTrue(torch.allclose(q_int, q_int.round(), atol=1e-4))
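
    # Editor's note (hedged, per the I-BERT paper): IntGELU is expected to
    # implement i-GELU, a second-order integer polynomial approximation of the
    # erf term in GELU, which is why a loose atol of 0.5 is used against
    # torch.nn.GELU rather than exact equality.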
    def test_force_dequant_gelu(self):
        x_int = torch.arange(-10000, 10001, 1.0)
        x_scaling_factor = torch.tensor(0.001)
        x = x_int * x_scaling_factor

        gelu_dq = IntGELU(quant_mode=False)
        gelu_fdqs_dict = {
            True: [
                IntGELU(quant_mode=True, force_dequant="nonlinear"),
                IntGELU(quant_mode=True, force_dequant="gelu"),
            ],
            False: [
                IntGELU(quant_mode=True, force_dequant="none"),
                IntGELU(quant_mode=True, force_dequant="softmax"),
                IntGELU(quant_mode=True, force_dequant="layernorm"),
            ],
        }

        dq, dq_scaling_factor = gelu_dq(x, x_scaling_factor)
        for label, gelu_fdqs in gelu_fdqs_dict.items():
            for gelu_fdq in gelu_fdqs:
                q, q_scaling_factor = gelu_fdq(x, x_scaling_factor)
                if label:
                    self.assertTrue(torch.allclose(q, dq, atol=1e-4))
                else:
                    self.assertFalse(torch.allclose(q, dq, atol=1e-4))

    def test_int_softmax(self):
        output_bit = 8
        softmax_q = IntSoftmax(output_bit, quant_mode=True)
        softmax_dq = torch.nn.Softmax(dim=-1)

        def _test(array):
            x_int = torch.tensor(array)
            x_scaling_factor = torch.tensor(0.1)
            x = x_int * x_scaling_factor

            q, q_scaling_factor = softmax_q(x, x_scaling_factor)
            q_int = q / q_scaling_factor
            dq = softmax_dq(x)

            # output of the normal Softmax and the quantized Softmax should be similar
            self.assertTrue(torch.allclose(q, dq, atol=0.5))

            # output of the quantized Softmax layer should be integer
            self.assertTrue(torch.allclose(q_int, q_int.round(), atol=1e-4))

            # output of the quantized Softmax should not exceed the range implied by output_bit
            self.assertTrue(q.abs().max() < 2 ** output_bit)

        array = [[i + j for j in range(10)] for i in range(-10, 10)]
        _test(array)
        array = [[i + j for j in range(50)] for i in range(-10, 10)]
        _test(array)
        array = [[i + 100 * j for j in range(2)] for i in range(-10, 10)]
        _test(array)
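
    # Editor's note (hedged, per the I-BERT paper): IntSoftmax is expected to
    # decompose exp(x) as 2**(-z) * exp(r) with r restricted to a bounded range
    # and to approximate exp(r) with a low-order integer polynomial, so the
    # three input patterns above probe small, wide, and strongly peaked ranges.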
    def test_force_dequant_softmax(self):
        output_bit = 8
        array = [[i + j for j in range(10)] for i in range(-10, 10)]
        x_int = torch.tensor(array)
        x_scaling_factor = torch.tensor(0.1)
        x = x_int * x_scaling_factor

        softmax_dq = IntSoftmax(output_bit, quant_mode=False)
        softmax_fdqs_dict = {
            True: [
                IntSoftmax(output_bit, quant_mode=True, force_dequant="nonlinear"),
                IntSoftmax(output_bit, quant_mode=True, force_dequant="softmax"),
            ],
            False: [
                IntSoftmax(output_bit, quant_mode=True, force_dequant="none"),
                IntSoftmax(output_bit, quant_mode=True, force_dequant="gelu"),
                IntSoftmax(output_bit, quant_mode=True, force_dequant="layernorm"),
            ],
        }

        dq, dq_scaling_factor = softmax_dq(x, x_scaling_factor)
        for label, softmax_fdqs in softmax_fdqs_dict.items():
            for softmax_fdq in softmax_fdqs:
                q, q_scaling_factor = softmax_fdq(x, x_scaling_factor)
                if label:
                    self.assertTrue(torch.allclose(q, dq, atol=1e-4))
                else:
                    self.assertFalse(torch.allclose(q, dq, atol=1e-4))

    def test_int_layernorm(self):
        output_bit = 8

        # some random matrix
        array = [[[i * j * j + j for j in range(5, 15)]] for i in range(-10, 10)]
        x_int = torch.tensor(array)
        x_scaling_factor = torch.tensor(0.1)
        x = x_int * x_scaling_factor

        ln_q = IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit)
        ln_dq = torch.nn.LayerNorm(x.shape[1:], 1e-5)

        ln_q.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
        ln_q.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))
        ln_dq.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
        ln_dq.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))

        q, q_scaling_factor = ln_q(x, x_scaling_factor)
        q_int = q / q_scaling_factor
        dq = ln_dq(x)

        # output of the normal LayerNorm and the quantized LayerNorm should be similar
        self.assertTrue(torch.allclose(q, dq, atol=0.5))

        # output of the quantized LayerNorm layer should be integer
        self.assertTrue(torch.allclose(q_int, q_int.round(), atol=1e-4))
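
    # Editor's note (hedged, per the I-BERT paper): IntLayerNorm is expected to
    # compute the standard deviation with an integer Newton-style square-root
    # iteration, so its output sits on an integer grid and only approximately
    # matches torch.nn.LayerNorm (hence atol=0.5 above).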
    def test_force_dequant_layernorm(self):
        output_bit = 8
        array = [[[i * j * j + j for j in range(5, 15)]] for i in range(-10, 10)]
        x_int = torch.tensor(array)
        x_scaling_factor = torch.tensor(0.1)
        x = x_int * x_scaling_factor

        ln_dq = IntLayerNorm(x.shape[1:], 1e-5, quant_mode=False, output_bit=output_bit)
        ln_fdqs_dict = {
            True: [
                IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit, force_dequant="nonlinear"),
                IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit, force_dequant="layernorm"),
            ],
            False: [
                IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit, force_dequant="none"),
                IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit, force_dequant="gelu"),
                IntLayerNorm(x.shape[1:], 1e-5, quant_mode=True, output_bit=output_bit, force_dequant="softmax"),
            ],
        }

        ln_dq.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
        ln_dq.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))
        dq, dq_scaling_factor = ln_dq(x, x_scaling_factor)
        for label, ln_fdqs in ln_fdqs_dict.items():
            for ln_fdq in ln_fdqs:
                ln_fdq.weight = torch.nn.Parameter(torch.ones(x.shape[1:]))
                ln_fdq.bias = torch.nn.Parameter(torch.ones(x.shape[1:]))
                q, q_scaling_factor = ln_fdq(x, x_scaling_factor)
                if label:
                    self.assertTrue(torch.allclose(q, dq, atol=1e-4))
                else:
                    self.assertFalse(torch.allclose(q, dq, atol=1e-4))

    def quantize(self, model):
        # Helper function that quantizes the given model
        # Recursively sets all the `quant_mode` attributes to `True`
        if hasattr(model, "quant_mode"):
            model.quant_mode = True
        elif type(model) == nn.Sequential:
            for n, m in model.named_children():
                self.quantize(m)
        elif type(model) == nn.ModuleList:
            for n in model:
                self.quantize(n)
        else:
            for attr in dir(model):
                mod = getattr(model, attr)
                if isinstance(mod, nn.Module) and mod != model:
                    self.quantize(mod)
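
    # Editor's usage sketch (hedged): `quantize` flips every reachable
    # `quant_mode` flag, so a model loaded in full precision can be switched to
    # integer-only inference in place, as the slow tests below do:
    #
    #   model = IBertForMaskedLM.from_pretrained("kssteven/ibert-roberta-base")
    #   self.quantize(model)  # all quantizable submodules now run in quant mode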
    @slow
    def test_inference_masked_lm(self):
        # I-BERT should be "equivalent" to RoBERTa if not quantized
        # Test copied from `test_modeling_roberta.py`
        model = IBertForMaskedLM.from_pretrained("kssteven/ibert-roberta-base")
        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
        output = model(input_ids)[0]
        expected_shape = torch.Size((1, 11, 50265))
        self.assertEqual(output.shape, expected_shape)
        expected_slice = torch.tensor(
            [[[33.8802, -4.3103, 22.7761], [4.6539, -2.8098, 13.6253], [1.8228, -3.6898, 8.8600]]]
        )
        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))

        # I-BERT should be "similar" to RoBERTa if quantized
        self.quantize(model)
        output = model(input_ids)[0]
        self.assertEqual(output.shape, expected_shape)
        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=0.1))

    @slow
    def test_inference_classification_head(self):
        # I-BERT should be "equivalent" to RoBERTa if not quantized
        # Test copied from `test_modeling_roberta.py`
        model = IBertForSequenceClassification.from_pretrained("kssteven/ibert-roberta-large-mnli")
        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
        output = model(input_ids)[0]
        expected_shape = torch.Size((1, 3))
        self.assertEqual(output.shape, expected_shape)
        expected_tensor = torch.tensor([[-0.9469, 0.3913, 0.5118]])
        self.assertTrue(torch.allclose(output, expected_tensor, atol=1e-4))

        # I-BERT should be "similar" to RoBERTa if quantized
        self.quantize(model)
        output = model(input_ids)[0]
        self.assertEqual(output.shape, expected_shape)
        self.assertTrue(torch.allclose(output, expected_tensor, atol=0.1))