Added BigBirdPegasus onnx config (#17104)
* Add onnx configuration for bigbird-pegasus
* Modify docs
This commit is contained in:
parent 351cdbdfdc
commit 215e0681e4
@@ -50,6 +50,7 @@ Ready-made configurations include the following architectures:
 - BEiT
 - BERT
 - BigBird
+- BigBirdPegasus
 - Blenderbot
 - BlenderbotSmall
 - CamemBERT
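For context, a minimal export sketch (not part of this commit) of how the newly listed architecture could be exported once the config below lands. It assumes the transformers.onnx export helper and the google/bigbird-pegasus-large-arxiv checkpoint that the slow tests use; treat the exact call signature as illustrative.

# Illustration only, not part of the diff.
from pathlib import Path

from transformers import AutoTokenizer, BigBirdPegasusForConditionalGeneration
from transformers.models.bigbird_pegasus import BigBirdPegasusOnnxConfig
from transformers.onnx import export

model_name = "google/bigbird-pegasus-large-arxiv"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BigBirdPegasusForConditionalGeneration.from_pretrained(model_name)

# Build the ONNX config for the seq2seq-lm feature and run the export.
onnx_config = BigBirdPegasusOnnxConfig(model.config, task="seq2seq-lm")
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("bigbird_pegasus.onnx")
)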
@@ -21,7 +21,11 @@ from ...utils import _LazyModule, is_torch_available


 _import_structure = {
-    "configuration_bigbird_pegasus": ["BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP", "BigBirdPegasusConfig"],
+    "configuration_bigbird_pegasus": [
+        "BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP",
+        "BigBirdPegasusConfig",
+        "BigBirdPegasusOnnxConfig",
+    ],
 }

 if is_torch_available():
@@ -37,7 +41,11 @@ if is_torch_available():


 if TYPE_CHECKING:
-    from .configuration_bigbird_pegasus import BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP, BigBirdPegasusConfig
+    from .configuration_bigbird_pegasus import (
+        BIGBIRD_PEGASUS_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        BigBirdPegasusConfig,
+        BigBirdPegasusOnnxConfig,
+    )

     if is_torch_available():
         from .modeling_bigbird_pegasus import (
@@ -14,8 +14,14 @@
 # limitations under the License.
 """ BigBirdPegasus model configuration"""

+from collections import OrderedDict
+from typing import Any, Mapping, Optional
+
+from ... import PreTrainedTokenizer
 from ...configuration_utils import PretrainedConfig
-from ...utils import logging
+from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
+from ...onnx.utils import compute_effective_axis_dimension
+from ...utils import TensorType, is_torch_available, logging


 logger = logging.get_logger(__name__)
@@ -185,3 +191,226 @@ class BigBirdPegasusConfig(PretrainedConfig):
            decoder_start_token_id=decoder_start_token_id,
            **kwargs,
        )


# Copied from transformers.models.bart.configuration_bart.BartOnnxConfig
class BigBirdPegasusOnnxConfig(OnnxSeq2SeqConfigWithPast):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )

            if self.use_past:
                common_inputs["decoder_input_ids"] = {0: "batch"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
            else:
                common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
                common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}

            if self.use_past:
                self.fill_with_past_key_values_(common_inputs, direction="inputs")
        elif self.task == "causal-lm":
            # TODO: figure this case out.
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                ]
            )
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        else:
            common_inputs = OrderedDict(
                [
                    ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                    ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                    ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
                    ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
                ]
            )

        return common_inputs

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        if self.task in ["default", "seq2seq-lm"]:
            common_outputs = super().outputs
        else:
            common_outputs = super(OnnxConfigWithPast, self).outputs
            if self.use_past:
                num_encoder_layers, _ = self.num_layers
                for i in range(num_encoder_layers):
                    common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                    common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
        return common_outputs

    def _generate_dummy_inputs_for_default_and_seq2seq_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        # Generate decoder inputs
        decoder_seq_length = seq_length if not self.use_past else 1
        decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, decoder_seq_length, is_pair, framework
        )
        decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
        common_inputs = dict(**encoder_inputs, **decoder_inputs)

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            batch, encoder_seq_length = common_inputs["input_ids"].shape
            decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
            num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
            encoder_shape = (
                batch,
                num_encoder_attention_heads,
                encoder_seq_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )
            decoder_past_length = decoder_seq_length + 3
            decoder_shape = (
                batch,
                num_decoder_attention_heads,
                decoder_past_length,
                self._config.hidden_size // num_decoder_attention_heads,
            )

            common_inputs["decoder_attention_mask"] = torch.cat(
                [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
            )

            common_inputs["past_key_values"] = []
            # If the number of encoder and decoder layers are present in the model configuration, both are considered
            num_encoder_layers, num_decoder_layers = self.num_layers
            min_num_layers = min(num_encoder_layers, num_decoder_layers)
            max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
            remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

            for _ in range(min_num_layers):
                common_inputs["past_key_values"].append(
                    (
                        torch.zeros(decoder_shape),
                        torch.zeros(decoder_shape),
                        torch.zeros(encoder_shape),
                        torch.zeros(encoder_shape),
                    )
                )
            # TODO: test this.
            shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
            for _ in range(min_num_layers, max_num_layers):
                common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
        return common_inputs

    def _generate_dummy_inputs_for_causal_lm(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
            tokenizer, batch_size, seq_length, is_pair, framework
        )

        if self.use_past:
            if not is_torch_available():
                raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
            else:
                import torch
            batch, seqlen = common_inputs["input_ids"].shape
            # Not using the same length for past_key_values
            past_key_values_length = seqlen + 2
            num_encoder_layers, _ = self.num_layers
            num_encoder_attention_heads, _ = self.num_attention_heads
            past_shape = (
                batch,
                num_encoder_attention_heads,
                past_key_values_length,
                self._config.hidden_size // num_encoder_attention_heads,
            )

            common_inputs["attention_mask"] = torch.cat(
                [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length)], dim=1
            )
            common_inputs["past_key_values"] = [
                (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
            ]
        return common_inputs

    def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        # Copied from OnnxConfig.generate_dummy_inputs
        # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
        # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
        batch_size = compute_effective_axis_dimension(
            batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
        )

        # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
        token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
        seq_length = compute_effective_axis_dimension(
            seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
        )

        # Generate dummy inputs according to compute batch and sequence
        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
        common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
        return common_inputs

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        if self.task in ["default", "seq2seq-lm"]:
            common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )

        elif self.task == "causal-lm":
            common_inputs = self._generate_dummy_inputs_for_causal_lm(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )
        else:
            common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
                tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
            )

        return common_inputs

    def _flatten_past_key_values_(self, flattened_output, name, idx, t):
        if self.task in ["default", "seq2seq-lm"]:
            flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
        else:
            flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
                flattened_output, name, idx, t
            )
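As a quick illustration of what the class above provides (not part of the diff; the names come from the code just shown), the declared dynamic axes and the dummy tensors used during export can be inspected directly. The checkpoint name is the one used in the slow tests further down.

# Illustration only: inspect the axes and dummy inputs the new config declares
# for the seq2seq-lm feature with past key values enabled.
from transformers import AutoTokenizer, BigBirdPegasusConfig
from transformers.models.bigbird_pegasus import BigBirdPegasusOnnxConfig
from transformers.utils import TensorType

onnx_config = BigBirdPegasusOnnxConfig(BigBirdPegasusConfig(), task="seq2seq-lm", use_past=True)

# Dynamic axes of the exported graph, including the flattened past_key_values.*
# entries added by fill_with_past_key_values_.
print(onnx_config.inputs)
print(onnx_config.outputs)

# Dummy tensors matching those axes (requires PyTorch, per the is_torch_available check above).
tokenizer = AutoTokenizer.from_pretrained("google/bigbird-pegasus-large-arxiv")
dummy_inputs = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
print(list(dummy_inputs.keys()))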
@@ -7,6 +7,7 @@ from ..models.bart import BartOnnxConfig
 from ..models.beit import BeitOnnxConfig
 from ..models.bert import BertOnnxConfig
 from ..models.big_bird import BigBirdOnnxConfig
+from ..models.bigbird_pegasus import BigBirdPegasusOnnxConfig
 from ..models.blenderbot import BlenderbotOnnxConfig
 from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
 from ..models.camembert import CamembertOnnxConfig
@@ -164,6 +165,17 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls=BigBirdOnnxConfig,
         ),
+        "bigbird-pegasus": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "causal-lm",
+            "causal-lm-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            "sequence-classification",
+            "question-answering",
+            onnx_config_cls=BigBirdPegasusOnnxConfig,
+        ),
         "blenderbot": supported_features_mapping(
             "default",
             "default-with-past",
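A quick way to confirm the new mapping is visible to the export machinery (illustrative only; the lookup helper is the existing FeaturesManager API, not something added in this commit):

# Illustration only: list the ONNX features now registered for bigbird-pegasus.
from transformers.onnx.features import FeaturesManager

features = FeaturesManager.get_supported_features_for_model_type("bigbird-pegasus")
print(sorted(features))  # default, causal-lm, seq2seq-lm, ... plus their -with-past variants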
@@ -201,6 +201,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
     ("m2m-100", "facebook/m2m100_418M"),
     ("blenderbot-small", "facebook/blenderbot_small-90M"),
     ("blenderbot", "facebook/blenderbot-400M-distill"),
+    ("bigbird-pegasus", "google/bigbird-pegasus-large-arxiv"),
 }

 # TODO(lewtun): Include the same model types in `PYTORCH_EXPORT_MODELS` once TensorFlow has parity with the PyTorch model implementations.