Add mt5 onnx config (#18394)
* update features
* MT5OnnxConfig added with tests and docs
* fix imports
* fix onnx_config_cls for mt5

Co-authored-by: Thomas Chaigneau <thomas.deeptools.ai>
parent fe785730dc
commit 8cb5ecd912
@@ -79,6 +79,7 @@ Ready-made configurations include the following architectures:
 - mBART
 - MobileBERT
 - MobileViT
+- MT5
 - OpenAI GPT-2
 - Perceiver
 - PLBart
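With MT5 listed among the ready-made configurations, the model can be exported either through the `python -m transformers.onnx` CLI or programmatically. Below is a minimal sketch of the programmatic route, assuming the `transformers.onnx.export` API of this release; the checkpoint name and output path are illustrative, not part of this diff.

# Minimal export sketch (illustrative checkpoint and output path).
from pathlib import Path

from transformers import AutoTokenizer, MT5ForConditionalGeneration
from transformers.models.mt5 import MT5OnnxConfig
from transformers.onnx import export

ckpt = "google/mt5-small"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = MT5ForConditionalGeneration.from_pretrained(ckpt)

# Build the ONNX config added in this PR and export at its default opset (13).
onnx_config = MT5OnnxConfig(model.config, task="seq2seq-lm")
onnx_inputs, onnx_outputs = export(
    tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("mt5.onnx")
)
print(onnx_inputs)   # names of the model inputs matched by the exporter
print(onnx_outputs)  # names of the graph outputs written to mt5.onnx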
@@ -43,7 +43,7 @@ else:
     MT5TokenizerFast = T5TokenizerFast

-_import_structure = {"configuration_mt5": ["MT5Config"]}
+_import_structure = {"configuration_mt5": ["MT5Config", "MT5OnnxConfig"]}

 try:
     if not is_torch_available():
@@ -71,7 +71,7 @@ else:


 if TYPE_CHECKING:
-    from .configuration_mt5 import MT5Config
+    from .configuration_mt5 import MT5Config, MT5OnnxConfig

 try:
     if not is_torch_available():
@@ -13,8 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ mT5 model configuration"""
+from typing import Mapping

 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxSeq2SeqConfigWithPast
 from ...utils import logging


@@ -143,3 +145,29 @@ class MT5Config(PretrainedConfig):
     @property
     def num_hidden_layers(self):
         return self.num_layers
+
+
+# Copied from transformers.models.t5.configuration_t5.T5OnnxConfig
+class MT5OnnxConfig(OnnxSeq2SeqConfigWithPast):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        common_inputs = {
+            "input_ids": {0: "batch", 1: "encoder_sequence"},
+            "attention_mask": {0: "batch", 1: "encoder_sequence"},
+        }
+        if self.use_past:
+            common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
+            common_inputs["decoder_input_ids"] = {0: "batch"}
+            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+        else:
+            common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+            common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+        if self.use_past:
+            self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+        return common_inputs
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 13
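The class above only declares dynamic axes; a quick way to see how use_past reshapes them is to instantiate the config both ways and inspect `inputs`. A sketch, assuming the `with_past` helper inherited from the `OnnxConfigWithPast` base classes; the printed values reflect the property defined above.

# Inspect the dynamic axes with and without past key values.
from transformers import MT5Config
from transformers.models.mt5 import MT5OnnxConfig

config = MT5Config()  # default hyperparameters are enough to look at the axes

no_past = MT5OnnxConfig(config, task="seq2seq-lm")
print(sorted(no_past.inputs))
# ['attention_mask', 'decoder_attention_mask', 'decoder_input_ids', 'input_ids']

with_past = MT5OnnxConfig.with_past(config, task="seq2seq-lm")
print(with_past.inputs["attention_mask"])
# {0: 'batch', 1: 'past_encoder_sequence + sequence'}
print([name for name in with_past.inputs if name.startswith("past_key_values")][:2])
# e.g. ['past_key_values.0.decoder.key', 'past_key_values.0.decoder.value']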
@@ -383,6 +383,13 @@ class FeaturesManager:
             "image-classification",
             onnx_config_cls="models.mobilevit.MobileViTOnnxConfig",
         ),
+        "mt5": supported_features_mapping(
+            "default",
+            "default-with-past",
+            "seq2seq-lm",
+            "seq2seq-lm-with-past",
+            onnx_config_cls="models.mt5.MT5OnnxConfig",
+        ),
         "m2m-100": supported_features_mapping(
             "default",
             "default-with-past",
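Registering the entry in FeaturesManager is what lets the exporter resolve "mt5" by name. A sketch of looking it up, assuming the lookup helpers this class exposes (`get_supported_features_for_model_type` and `get_config`):

# Discover the features registered for mt5 and fetch a config constructor.
from transformers.onnx.features import FeaturesManager

features = FeaturesManager.get_supported_features_for_model_type("mt5")
print(sorted(features))
# ['default', 'default-with-past', 'seq2seq-lm', 'seq2seq-lm-with-past']

# Each feature maps to a constructor for MT5OnnxConfig with the matching task/use_past.
onnx_config_ctor = FeaturesManager.get_config("mt5", "seq2seq-lm-with-past")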
@@ -224,6 +224,7 @@ PYTORCH_EXPORT_SEQ2SEQ_WITH_PAST_MODELS = {
    ("mbart", "sshleifer/tiny-mbart"),
    ("t5", "t5-small"),
    ("marian", "Helsinki-NLP/opus-mt-en-de"),
+    ("mt5", "google/mt5-base"),
    ("m2m-100", "facebook/m2m100_418M"),
    ("blenderbot-small", "facebook/blenderbot_small-90M"),
    ("blenderbot", "facebook/blenderbot-400M-distill"),
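The new test entry exercises the seq2seq export with past key values against google/mt5-base (a slow test). Outside the test suite, an exported file can be sanity-checked by running it with onnxruntime; a sketch, assuming a seq2seq-lm export named mt5.onnx as produced earlier and an onnxruntime install (checkpoint and input text are illustrative).

# Run one forward pass through the exported graph with onnxruntime.
import numpy as np
import onnxruntime

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
session = onnxruntime.InferenceSession("mt5.onnx")

encoded = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="np")
decoder_start = np.array([[tokenizer.pad_token_id]], dtype=np.int64)  # mT5 starts decoding from <pad>
outputs = session.run(
    None,
    {
        "input_ids": encoded["input_ids"].astype(np.int64),
        "attention_mask": encoded["attention_mask"].astype(np.int64),
        "decoder_input_ids": decoder_start,
        "decoder_attention_mask": np.ones_like(decoder_start),
    },
)
print([o.shape for o in outputs])  # e.g. one logits tensor of shape (batch, decoder_sequence, vocab_size)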