Add mT5 ONNX config (#18394)

* update features

* MT5OnnxConfig added, with tests and docs updated

* fix imports

* fix onnx_config_cls for mt5

Co-authored-by: Thomas Chaigneau <thomas.deeptools.ai>
Thomas Chaigneau 2022-08-09 09:46:53 +02:00 committed by GitHub
parent fe785730dc
commit 8cb5ecd912
5 changed files with 39 additions and 2 deletions

docs/source/en/serialization.mdx

@@ -79,6 +79,7 @@ Ready-made configurations include the following architectures:
 - mBART
 - MobileBERT
 - MobileViT
+- MT5
 - OpenAI GPT-2
 - Perceiver
 - PLBart
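
With mT5 added to the ready-made configurations, the documented python -m transformers.onnx entry point can target mT5 checkpoints. A minimal sketch of driving that CLI from Python; the checkpoint google/mt5-small, the seq2seq-lm feature choice, and the onnx/ output directory are illustrative, not part of this commit:

import subprocess
import sys

# Drive the documented ONNX export CLI from Python; checkpoint, feature and
# output path are illustrative assumptions.
subprocess.run(
    [
        sys.executable,
        "-m",
        "transformers.onnx",
        "--model=google/mt5-small",
        "--feature=seq2seq-lm",
        "onnx/",
    ],
    check=True,
)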

src/transformers/models/mt5/__init__.py

@@ -43,7 +43,7 @@ else:
     MT5TokenizerFast = T5TokenizerFast
-_import_structure = {"configuration_mt5": ["MT5Config"]}
+_import_structure = {"configuration_mt5": ["MT5Config", "MT5OnnxConfig"]}
 try:
     if not is_torch_available():
@@ -71,7 +71,7 @@ else:
 if TYPE_CHECKING:
-    from .configuration_mt5 import MT5Config
+    from .configuration_mt5 import MT5Config, MT5OnnxConfig
     try:
         if not is_torch_available():
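
With both the lazy _import_structure and the TYPE_CHECKING branch updated, the new class resolves next to MT5Config. A quick sanity check, assuming a build of transformers that includes this commit:

from transformers.models.mt5 import MT5Config, MT5OnnxConfig
from transformers.onnx import OnnxSeq2SeqConfigWithPast

# MT5OnnxConfig is now importable from the mT5 subpackage and subclasses the
# seq2seq-with-past ONNX config base also used by T5.
assert issubclass(MT5OnnxConfig, OnnxSeq2SeqConfigWithPast)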

src/transformers/models/mt5/configuration_mt5.py

@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ mT5 model configuration"""
+from typing import Mapping
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxSeq2SeqConfigWithPast
 from ...utils import logging
@@ -143,3 +145,29 @@ class MT5Config(PretrainedConfig):
     @property
     def num_hidden_layers(self):
         return self.num_layers
+
+
+# Copied from transformers.models.t5.configuration_t5.T5OnnxConfig
+class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = {
+            "input_ids": {0: "batch", 1: "encoder_sequence"},
+            "attention_mask": {0: "batch", 1: "encoder_sequence"},
+        }
+        if self.use_past:
+            common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
+            common_inputs["decoder_input_ids"] = {0: "batch"}
+            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+        else:
+            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+        if self.use_past:
+            self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+        return common_inputs
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 13
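
Because the class is copied from T5OnnxConfig, its dynamic axes can be inspected without loading a checkpoint. A short sketch, with the expected values taken from the property definitions above:

from transformers import MT5Config
from transformers.models.mt5 import MT5OnnxConfig

config = MT5Config()  # default hyperparameters are enough to inspect the axes

# Without past key values, the decoder inputs keep their own sequence axis.
onnx_config = MT5OnnxConfig(config, task="seq2seq-lm")
print(onnx_config.inputs["decoder_input_ids"])  # {0: "batch", 1: "decoder_sequence"}
print(onnx_config.default_onnx_opset)  # 13

# The with-past variant drops the decoder sequence axis and adds past key/values.
onnx_config_with_past = MT5OnnxConfig.with_past(config, task="seq2seq-lm")
print(onnx_config_with_past.inputs["decoder_input_ids"])  # {0: "batch"}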

src/transformers/onnx/features.py

@@ -383,6 +383,13 @@ class FeaturesManager:
             "image-classification",
             onnx_config_cls="models.mobilevit.MobileViTOnnxConfig",
         ),
+        "mt5": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            onnx_config_cls="models.mt5.MT5OnnxConfig",
+        ),
         "m2m-100": supported_features_mapping(
             "default",
             "default-with-past",

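With this FeaturesManager entry, the export machinery can resolve the mt5 model type by name; supported_features_mapping binds each listed task to the class named in onnx_config_cls. A hedged sketch of the lookup:

from transformers.onnx import FeaturesManager

# The new entry makes "mt5" resolvable by the ONNX export machinery.
mt5_features = FeaturesManager.get_supported_features_for_model_type("mt5")
print(sorted(mt5_features))
# Expected: ['default', 'default-with-past', 'seq2seq-lm', 'seq2seq-lm-with-past']
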
tests/onnx/test_onnx_v2.py

@@ -224,6 +224,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
     ("mbart", "sshleifer/tiny-mbart"),
     ("t5", "t5-small"),
     ("marian", "Helsinki-NLP/opus-mt-en-de"),
+    ("mt5", "google/mt5-base"),
     ("m2m-100", "facebook/m2m100_418M"),
     ("blenderbot-small", "facebook/blenderbot_small-90M"),
     ("blenderbot", "facebook/blenderbot-400M-distill"),
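
The new test matrix entry exercises the with-past seq2seq export against google/mt5-base. Outside the test suite, roughly the same round trip can be sketched as below; the smaller google/mt5-small checkpoint and the output path are illustrative assumptions:

from pathlib import Path

from transformers import AutoTokenizer, MT5ForConditionalGeneration
from transformers.models.mt5 import MT5OnnxConfig
from transformers.onnx import export, validate_model_outputs

model_name = "google/mt5-small"  # illustrative; the test matrix above uses google/mt5-base
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)

# seq2seq-lm-with-past: cached key/values become explicit ONNX inputs and outputs.
onnx_config = MT5OnnxConfig.with_past(model.config, task="seq2seq-lm")
onnx_path = Path("mt5-with-past.onnx")
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path
)

# Compare ONNX Runtime outputs against the PyTorch reference, as the test suite does.
validate_model_outputs(
    onnx_config, tokenizer, model, onnx_path, onnx_outputs, onnx_config.atol_for_validation
)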