Add onnx export of models with a multiple choice classification head (#16758)

* Add export of models with a multiple-choice classification head
This commit is contained in:
Ella Charlaix 2022-04-19 15:51:51 +02:00 committed by GitHub
parent b74a955325
commit 77de8d6c31
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 134 additions and 41 deletions

View File

@ -159,10 +159,14 @@ class AlbertConfig(PretrainedConfig):
class AlbertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("token_type_ids", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)

View File

@ -160,10 +160,14 @@ class BertConfig(PretrainedConfig):
class BertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("token_type_ids", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)

View File

@ -168,9 +168,13 @@ class BigBirdConfig(PretrainedConfig):
class BigBirdOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
]
)

View File

@ -44,9 +44,13 @@ class CamembertConfig(RobertaConfig):
class CamembertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
]
)

View File

@ -139,9 +139,13 @@ class Data2VecTextConfig(PretrainedConfig):
class Data2VecTextOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
]
)

View File

@ -134,9 +134,13 @@ class DistilBertConfig(PretrainedConfig):
class DistilBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
]
)

View File

@ -179,10 +179,14 @@ class ElectraConfig(PretrainedConfig):
class ElectraOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("token_type_ids", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
("token_type_ids", dynamic_axis),
]
)

View File

@ -146,9 +146,13 @@ class FlaubertConfig(XLMConfig):
class FlaubertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
]
)

View File

@ -234,7 +234,7 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
# We need to order the input in the way they appears in the forward()

View File

@ -233,7 +233,7 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
) -> Mapping[str, Any]:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
# We need to order the input in the way they appears in the forward()

View File

@ -183,7 +183,7 @@ class GPTJOnnxConfig(OnnxConfigWithPast):
framework: Optional[TensorType] = None,
) -> Mapping[str, Any]:
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
# We need to order the input in the way they appears in the forward()

View File

@ -131,9 +131,13 @@ class IBertConfig(PretrainedConfig):
class IBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
]
)

View File

@ -171,7 +171,9 @@ class LayoutLMOnnxConfig(OnnxConfig):
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
"""
input_dict = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
input_dict = super().generate_dummy_inputs(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
# Generate a dummy bbox
box = [48, 84, 73, 128]

View File

@ -70,9 +70,13 @@ class RobertaConfig(BertConfig):
class RobertaOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
]
)

View File

@ -47,9 +47,13 @@ class XLMRobertaConfig(RobertaConfig):
class XLMRobertaOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
]
)

View File

@ -143,9 +143,13 @@ class XLMRobertaXLConfig(PretrainedConfig):
class XLMRobertaXLOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
if self.task == "multiple-choice":
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
else:
dynamic_axis = {0: "batch", 1: "sequence"}
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
("input_ids", dynamic_axis),
("attention_mask", dynamic_axis),
]
)

View File

@ -71,6 +71,7 @@ class OnnxConfig(ABC):
default_fixed_batch = 2
default_fixed_sequence = 8
default_fixed_num_choices = 4
torch_onnx_minimum_version = version.parse("1.8")
_tasks_to_common_outputs = {
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
@ -174,6 +175,16 @@ class OnnxConfig(ABC):
"""
return OnnxConfig.default_fixed_sequence
@property
def default_num_choices(self) -> int:
"""
The default number of choices to use if no other indication
Returns:
Integer > 0
"""
return OnnxConfig.default_fixed_num_choices
@property
def default_onnx_opset(self) -> int:
"""
@ -240,6 +251,7 @@ class OnnxConfig(ABC):
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
batch_size: int = -1,
seq_length: int = -1,
num_choices: int = -1,
is_pair: bool = False,
framework: Optional[TensorType] = None,
num_channels: int = 3,
@ -255,6 +267,8 @@ class OnnxConfig(ABC):
The preprocessor associated with this model configuration.
batch_size (`int`, *optional*, defaults to -1):
The batch size to export the model for (-1 means dynamic axis).
num_choices (`int`, *optional*, defaults to -1):
The number of candidate answers provided for multiple choice task (-1 means dynamic axis).
seq_length (`int`, *optional*, defaults to -1):
The sequence length to export the model for (-1 means dynamic axis).
is_pair (`bool`, *optional*, defaults to `False`):
@ -295,6 +309,19 @@ class OnnxConfig(ABC):
)
# Generate dummy inputs according to compute batch and sequence
dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
if self.task == "multiple-choice":
# If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid optimizations
# made by ONNX
num_choices = compute_effective_axis_dimension(
num_choices, fixed_dimension=OnnxConfig.default_fixed_num_choices, num_token_to_add=0
)
dummy_input = dummy_input * num_choices
# The shape of the tokenized inputs values is [batch_size * num_choices, seq_length]
tokenized_input = preprocessor(dummy_input, text_pair=dummy_input)
# Unflatten the tokenized inputs values expanding it to the shape [batch_size, num_choices, seq_length]
for k, v in tokenized_input.items():
tokenized_input[k] = [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
return dict(tokenized_input.convert_to_tensors(tensor_type=framework))
return dict(preprocessor(dummy_input, return_tensors=framework))
elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
@ -408,7 +435,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
) -> Mapping[str, Any]:
# TODO: should we set seq_length = 1 when self.use_past = True?
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
common_inputs = super().generate_dummy_inputs(
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
if self.use_past:
if not is_torch_available():
@ -527,13 +556,13 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
) -> Mapping[str, Any]:
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, seq_length, is_pair, framework
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
)
# Generate decoder inputs
decoder_seq_length = seq_length if not self.use_past else 1
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
tokenizer, batch_size, decoder_seq_length, is_pair, framework
tokenizer, batch_size=batch_size, seq_length=decoder_seq_length, is_pair=is_pair, framework=framework
)
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
common_inputs = dict(**encoder_inputs, **decoder_inputs)

View File

@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig
from ..models.blenderbot import BlenderbotOnnxConfig
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
from ..models.camembert import CamembertOnnxConfig
from ..models.data2vec import Data2VecTextOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig
from ..models.flaubert import FlaubertOnnxConfig
@ -120,7 +121,7 @@ class FeaturesManager:
"default",
"masked-lm",
"sequence-classification",
# "multiple-choice",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=AlbertOnnxConfig,
@ -152,7 +153,7 @@ class FeaturesManager:
"masked-lm",
"causal-lm",
"sequence-classification",
# "multiple-choice",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=BertOnnxConfig,
@ -162,6 +163,7 @@ class FeaturesManager:
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=BigBirdOnnxConfig,
@ -170,7 +172,7 @@ class FeaturesManager:
"default",
"masked-lm",
"sequence-classification",
# "multiple-choice",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=IBertOnnxConfig,
@ -180,7 +182,7 @@ class FeaturesManager:
"masked-lm",
"causal-lm",
"sequence-classification",
# "multiple-choice",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=CamembertOnnxConfig,
@ -189,7 +191,7 @@ class FeaturesManager:
"default",
"masked-lm",
"sequence-classification",
# "multiple-choice",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=DistilBertOnnxConfig,
@ -199,6 +201,7 @@ class FeaturesManager:
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=FlaubertOnnxConfig,
@ -220,7 +223,7 @@ class FeaturesManager:
"masked-lm",
"causal-lm",
"sequence-classification",
# "multiple-choice",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=RobertaOnnxConfig,
@ -233,7 +236,7 @@ class FeaturesManager:
"masked-lm",
"causal-lm",
"sequence-classification",
# "multiple-choice",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=XLMRobertaOnnxConfig,
@ -276,6 +279,7 @@ class FeaturesManager:
"masked-lm",
"causal-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=ElectraOnnxConfig,
@ -300,6 +304,15 @@ class FeaturesManager:
"seq2seq-lm-with-past",
onnx_config_cls=BlenderbotSmallOnnxConfig,
),
"data2vec-text": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
"multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=Data2VecTextOnnxConfig,
),
}
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))

View File

@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = {
("layoutlm", "microsoft/layoutlm-base-uncased"),
("vit", "google/vit-base-patch16-224"),
("beit", "microsoft/beit-base-patch16-224"),
("data2vec-text", "facebook/data2vec-text-base"),
}
PYTORCH_EXPORT_WITH_PAST_MODELS = {