From 215e0681e4c3f6ade6e219d022a5e640b42fcb76 Mon Sep 17 00:00:00 2001 From: Ritik Nandwal <48522685+nandwalritik@users.noreply.github.com> Date: Fri, 6 May 2022 21:01:00 +0530 Subject: [PATCH] Added BigBirdPegasus onnx config (#17104) * Add onnx configuration for bigbird-pegasus * Modify docs --- docs/source/en/serialization.mdx | 1 + .../models/bigbird_pegasus/__init__.py | 12 +- .../configuration_bigbird_pegasus.py | 231 +++++++++++++++++- src/transformers/onnx/features.py | 12 + tests/onnx/test_onnx_v2.py | 1 + 5 files changed, 254 insertions(+), 3 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 4ae5c9a57e..0866632020 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -50,6 +50,7 @@ Ready-made configurations include the following architectures: - BEiT - BERT - BigBird +- BigBirdPegasus - Blenderbot - BlenderbotSmall - CamemBERT diff --git a/src/transformers/models/bigbird_pegasus/__init__.py b/src/transformers/models/bigbird_pegasus/__init__.py index fd5b3671e2..879bea21dd 100644 --- a/src/transformers/models/bigbird_pegasus/__init__.py +++ b/src/transformers/models/bigbird_pegasus/__init__.py @@ -21,7 +21,11 @@ from ...utils import _LazyModule, is_torch_available _import_structure = { - "configuration_bigbird_pegasus": ["BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdPegasusConfig"], + "configuration_bigbird_pegasus": [ + "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", + "BigBirdPegasusConfig", + "BigBirdPegasusOnnxConfig", + ], } if is_torch_available(): @@ -37,7 +41,11 @@ if is_torch_available(): if TYPE_CHECKING: - from .configuration_bigbird_pegasus import BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdPegasusConfig + from .configuration_bigbird_pegasus import ( + BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, + BigBirdPegasusConfig, + BigBirdPegasusOnnxConfig, + ) if is_torch_available(): from .modeling_bigbird_pegasus import ( diff --git a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py index 7c9c0ec5a0..c198c84aa7 100644 --- a/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py @@ -14,8 +14,14 @@ # limitations under the License. """ BigBirdPegasus model configuration""" +from collections import OrderedDict +from typing import Any, Mapping, Optional + +from ... import PreTrainedTokenizer from ...configuration_utils import PretrainedConfig -from ...utils import logging +from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast +from ...onnx.utils import compute_effective_axis_dimension +from ...utils import TensorType, is_torch_available, logging logger = logging.get_logger(__name__) @@ -185,3 +191,226 @@ class BigBirdPegasusConfig(PretrainedConfig): decoder_start_token_id=decoder_start_token_id, **kwargs, ) + + +# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig +class BigBirdPegasusOnnxConfig(OnnxSeq2SeqConfigWithPast): + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + + if self.use_past: + common_inputs["decoder_input_ids"] = {0: "batch"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"} + else: + common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"} + common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"} + + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + elif self.task == "causal-lm": + # TODO: figure this case out. + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ] + ) + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + else: + common_inputs = OrderedDict( + [ + ("input_ids", {0: "batch", 1: "encoder_sequence"}), + ("attention_mask", {0: "batch", 1: "encoder_sequence"}), + ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}), + ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}), + ] + ) + + return common_inputs + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task in ["default", "seq2seq-lm"]: + common_outputs = super().outputs + else: + common_outputs = super(OnnxConfigWithPast, self).outputs + if self.use_past: + num_encoder_layers, _ = self.num_layers + for i in range(num_encoder_layers): + common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"} + common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"} + return common_outputs + + def _generate_dummy_inputs_for_default_and_seq2seq_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + # Generate decoder inputs + decoder_seq_length = seq_length if not self.use_past else 1 + decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, decoder_seq_length, is_pair, framework + ) + decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()} + common_inputs = dict(**encoder_inputs, **decoder_inputs) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, encoder_seq_length = common_inputs["input_ids"].shape + decoder_seq_length = common_inputs["decoder_input_ids"].shape[1] + num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads + encoder_shape = ( + batch, + num_encoder_attention_heads, + encoder_seq_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + decoder_past_length = decoder_seq_length + 3 + decoder_shape = ( + batch, + num_decoder_attention_heads, + decoder_past_length, + self._config.hidden_size // num_decoder_attention_heads, + ) + + common_inputs["decoder_attention_mask"] = torch.cat( + [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1 + ) + + common_inputs["past_key_values"] = [] + # If the number of encoder and decoder layers are present in the model configuration, both are considered + num_encoder_layers, num_decoder_layers = self.num_layers + min_num_layers = min(num_encoder_layers, num_decoder_layers) + max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers + remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder" + + for _ in range(min_num_layers): + common_inputs["past_key_values"].append( + ( + torch.zeros(decoder_shape), + torch.zeros(decoder_shape), + torch.zeros(encoder_shape), + torch.zeros(encoder_shape), + ) + ) + # TODO: test this. + shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape + for _ in range(min_num_layers, max_num_layers): + common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape))) + return common_inputs + + def _generate_dummy_inputs_for_causal_lm( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size, seq_length, is_pair, framework + ) + + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + num_encoder_layers, _ = self.num_layers + num_encoder_attention_heads, _ = self.num_attention_heads + past_shape = ( + batch, + num_encoder_attention_heads, + past_key_values_length, + self._config.hidden_size // num_encoder_attention_heads, + ) + + common_inputs["attention_mask"] = torch.cat( + [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1 + ) + common_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers) + ] + return common_inputs + + def _generate_dummy_inputs_for_sequence_classification_and_question_answering( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + # Copied from OnnxConfig.generate_dummy_inputs + # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity. + # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX + batch_size = compute_effective_axis_dimension( + batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0 + ) + + # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX + token_to_add = tokenizer.num_special_tokens_to_add(is_pair) + seq_length = compute_effective_axis_dimension( + seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add + ) + + # Generate dummy inputs according to compute batch and sequence + dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size + common_inputs = dict(tokenizer(dummy_input, return_tensors=framework)) + return common_inputs + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + if self.task in ["default", "seq2seq-lm"]: + common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + elif self.task == "causal-lm": + common_inputs = self._generate_dummy_inputs_for_causal_lm( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + else: + common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + return common_inputs + + def _flatten_past_key_values_(self, flattened_output, name, idx, t): + if self.task in ["default", "seq2seq-lm"]: + flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t) + else: + flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_( + flattened_output, name, idx, t + ) diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 50e941332a..852a8e9071 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -7,6 +7,7 @@ from ..models.bart import BartOnnxConfig from ..models.beit import BeitOnnxConfig from ..models.bert import BertOnnxConfig from ..models.big_bird import BigBirdOnnxConfig +from ..models.bigbird_pegasus import BigBirdPegasusOnnxConfig from ..models.blenderbot import BlenderbotOnnxConfig from ..models.blenderbot_small import BlenderbotSmallOnnxConfig from ..models.camembert import CamembertOnnxConfig @@ -164,6 +165,17 @@ class FeaturesManager: "question-answering", onnx_config_cls=BigBirdOnnxConfig, ), + "bigbird-pegasus": supported_features_mapping( + "default", + "default-with-past", + "causal-lm", + "causal-lm-with-past", + "seq2seq-lm", + "seq2seq-lm-with-past", + "sequence-classification", + "question-answering", + onnx_config_cls=BigBirdPegasusOnnxConfig, + ), "blenderbot": supported_features_mapping( "default", "default-with-past", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 4ecfc917d5..9fc228f895 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -201,6 +201,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = { ("m2m-100", "facebook/m2m100_418M"), ("blenderbot-small", "facebook/blenderbot_small-90M"), ("blenderbot", "facebook/blenderbot-400M-distill"), + ("bigbird-pegasus", "google/bigbird-pegasus-large-arxiv"), } # TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.