Add onnx export of models with a multiple choice classification head (#16758)
* Add export of models with a multiple-choice classification head
This commit is contained in:
parent
b74a955325
commit
77de8d6c31
|
@ -159,10 +159,14 @@ class AlbertConfig(PretrainedConfig):
|
||||||
class AlbertOnnxConfig(OnnxConfig):
|
class AlbertOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
("token_type_ids", {0: "batch", 1: "sequence"}),
|
("token_type_ids", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -160,10 +160,14 @@ class BertConfig(PretrainedConfig):
|
||||||
class BertOnnxConfig(OnnxConfig):
|
class BertOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
("token_type_ids", {0: "batch", 1: "sequence"}),
|
("token_type_ids", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -168,9 +168,13 @@ class BigBirdConfig(PretrainedConfig):
|
||||||
class BigBirdOnnxConfig(OnnxConfig):
|
class BigBirdOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -44,9 +44,13 @@ class CamembertConfig(RobertaConfig):
|
||||||
class CamembertOnnxConfig(OnnxConfig):
|
class CamembertOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -139,9 +139,13 @@ class Data2VecTextConfig(PretrainedConfig):
|
||||||
class Data2VecTextOnnxConfig(OnnxConfig):
|
class Data2VecTextOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -134,9 +134,13 @@ class DistilBertConfig(PretrainedConfig):
|
||||||
class DistilBertOnnxConfig(OnnxConfig):
|
class DistilBertOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -179,10 +179,14 @@ class ElectraConfig(PretrainedConfig):
|
||||||
class ElectraOnnxConfig(OnnxConfig):
|
class ElectraOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
("token_type_ids", {0: "batch", 1: "sequence"}),
|
("token_type_ids", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -146,9 +146,13 @@ class FlaubertConfig(XLMConfig):
|
||||||
class FlaubertOnnxConfig(OnnxConfig):
|
class FlaubertOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -234,7 +234,7 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
|
||||||
framework: Optional[TensorType] = None,
|
framework: Optional[TensorType] = None,
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||||
tokenizer, batch_size, seq_length, is_pair, framework
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
)
|
)
|
||||||
|
|
||||||
# We need to order the input in the way they appears in the forward()
|
# We need to order the input in the way they appears in the forward()
|
||||||
|
|
|
@ -233,7 +233,7 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
|
|
||||||
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||||
tokenizer, batch_size, seq_length, is_pair, framework
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
)
|
)
|
||||||
|
|
||||||
# We need to order the input in the way they appears in the forward()
|
# We need to order the input in the way they appears in the forward()
|
||||||
|
|
|
@ -183,7 +183,7 @@ class GPTJOnnxConfig(OnnxConfigWithPast):
|
||||||
framework: Optional[TensorType] = None,
|
framework: Optional[TensorType] = None,
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||||
tokenizer, batch_size, seq_length, is_pair, framework
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
)
|
)
|
||||||
|
|
||||||
# We need to order the input in the way they appears in the forward()
|
# We need to order the input in the way they appears in the forward()
|
||||||
|
|
|
@ -131,9 +131,13 @@ class IBertConfig(PretrainedConfig):
|
||||||
class IBertOnnxConfig(OnnxConfig):
|
class IBertOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -171,7 +171,9 @@ class LayoutLMOnnxConfig(OnnxConfig):
|
||||||
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
|
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input_dict = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
input_dict = super().generate_dummy_inputs(
|
||||||
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
|
)
|
||||||
|
|
||||||
# Generate a dummy bbox
|
# Generate a dummy bbox
|
||||||
box = [48, 84, 73, 128]
|
box = [48, 84, 73, 128]
|
||||||
|
|
|
@ -70,9 +70,13 @@ class RobertaConfig(BertConfig):
|
||||||
class RobertaOnnxConfig(OnnxConfig):
|
class RobertaOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -47,9 +47,13 @@ class XLMRobertaConfig(RobertaConfig):
|
||||||
class XLMRobertaOnnxConfig(OnnxConfig):
|
class XLMRobertaOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -143,9 +143,13 @@ class XLMRobertaXLConfig(PretrainedConfig):
|
||||||
class XLMRobertaXLOnnxConfig(OnnxConfig):
|
class XLMRobertaXLOnnxConfig(OnnxConfig):
|
||||||
@property
|
@property
|
||||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
return OrderedDict(
|
return OrderedDict(
|
||||||
[
|
[
|
||||||
("input_ids", {0: "batch", 1: "sequence"}),
|
("input_ids", dynamic_axis),
|
||||||
("attention_mask", {0: "batch", 1: "sequence"}),
|
("attention_mask", dynamic_axis),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
|
@ -71,6 +71,7 @@ class OnnxConfig(ABC):
|
||||||
|
|
||||||
default_fixed_batch = 2
|
default_fixed_batch = 2
|
||||||
default_fixed_sequence = 8
|
default_fixed_sequence = 8
|
||||||
|
default_fixed_num_choices = 4
|
||||||
torch_onnx_minimum_version = version.parse("1.8")
|
torch_onnx_minimum_version = version.parse("1.8")
|
||||||
_tasks_to_common_outputs = {
|
_tasks_to_common_outputs = {
|
||||||
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
|
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
|
||||||
|
@ -174,6 +175,16 @@ class OnnxConfig(ABC):
|
||||||
"""
|
"""
|
||||||
return OnnxConfig.default_fixed_sequence
|
return OnnxConfig.default_fixed_sequence
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_num_choices(self) -> int:
|
||||||
|
"""
|
||||||
|
The default number of choices to use if no other indication
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Integer > 0
|
||||||
|
"""
|
||||||
|
return OnnxConfig.default_fixed_num_choices
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def default_onnx_opset(self) -> int:
|
def default_onnx_opset(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
@ -240,6 +251,7 @@ class OnnxConfig(ABC):
|
||||||
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
|
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
|
||||||
batch_size: int = -1,
|
batch_size: int = -1,
|
||||||
seq_length: int = -1,
|
seq_length: int = -1,
|
||||||
|
num_choices: int = -1,
|
||||||
is_pair: bool = False,
|
is_pair: bool = False,
|
||||||
framework: Optional[TensorType] = None,
|
framework: Optional[TensorType] = None,
|
||||||
num_channels: int = 3,
|
num_channels: int = 3,
|
||||||
|
@ -255,6 +267,8 @@ class OnnxConfig(ABC):
|
||||||
The preprocessor associated with this model configuration.
|
The preprocessor associated with this model configuration.
|
||||||
batch_size (`int`, *optional*, defaults to -1):
|
batch_size (`int`, *optional*, defaults to -1):
|
||||||
The batch size to export the model for (-1 means dynamic axis).
|
The batch size to export the model for (-1 means dynamic axis).
|
||||||
|
num_choices (`int`, *optional*, defaults to -1):
|
||||||
|
The number of candidate answers provided for multiple choice task (-1 means dynamic axis).
|
||||||
seq_length (`int`, *optional*, defaults to -1):
|
seq_length (`int`, *optional*, defaults to -1):
|
||||||
The sequence length to export the model for (-1 means dynamic axis).
|
The sequence length to export the model for (-1 means dynamic axis).
|
||||||
is_pair (`bool`, *optional*, defaults to `False`):
|
is_pair (`bool`, *optional*, defaults to `False`):
|
||||||
|
@ -295,6 +309,19 @@ class OnnxConfig(ABC):
|
||||||
)
|
)
|
||||||
# Generate dummy inputs according to compute batch and sequence
|
# Generate dummy inputs according to compute batch and sequence
|
||||||
dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
|
dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
# If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid optimizations
|
||||||
|
# made by ONNX
|
||||||
|
num_choices = compute_effective_axis_dimension(
|
||||||
|
num_choices, fixed_dimension=OnnxConfig.default_fixed_num_choices, num_token_to_add=0
|
||||||
|
)
|
||||||
|
dummy_input = dummy_input * num_choices
|
||||||
|
# The shape of the tokenized inputs values is [batch_size * num_choices, seq_length]
|
||||||
|
tokenized_input = preprocessor(dummy_input, text_pair=dummy_input)
|
||||||
|
# Unflatten the tokenized inputs values expanding it to the shape [batch_size, num_choices, seq_length]
|
||||||
|
for k, v in tokenized_input.items():
|
||||||
|
tokenized_input[k] = [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
|
||||||
|
return dict(tokenized_input.convert_to_tensors(tensor_type=framework))
|
||||||
return dict(preprocessor(dummy_input, return_tensors=framework))
|
return dict(preprocessor(dummy_input, return_tensors=framework))
|
||||||
elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
|
elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
|
||||||
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
||||||
|
@ -408,7 +435,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
|
|
||||||
# TODO: should we set seq_length = 1 when self.use_past = True?
|
# TODO: should we set seq_length = 1 when self.use_past = True?
|
||||||
common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
|
common_inputs = super().generate_dummy_inputs(
|
||||||
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
|
)
|
||||||
|
|
||||||
if self.use_past:
|
if self.use_past:
|
||||||
if not is_torch_available():
|
if not is_torch_available():
|
||||||
|
@ -527,13 +556,13 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
|
||||||
) -> Mapping[str, Any]:
|
) -> Mapping[str, Any]:
|
||||||
|
|
||||||
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||||
tokenizer, batch_size, seq_length, is_pair, framework
|
tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
)
|
)
|
||||||
|
|
||||||
# Generate decoder inputs
|
# Generate decoder inputs
|
||||||
decoder_seq_length = seq_length if not self.use_past else 1
|
decoder_seq_length = seq_length if not self.use_past else 1
|
||||||
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
|
||||||
tokenizer, batch_size, decoder_seq_length, is_pair, framework
|
tokenizer, batch_size=batch_size, seq_length=decoder_seq_length, is_pair=is_pair, framework=framework
|
||||||
)
|
)
|
||||||
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
|
||||||
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
common_inputs = dict(**encoder_inputs, **decoder_inputs)
|
||||||
|
|
|
@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig
|
||||||
from ..models.blenderbot import BlenderbotOnnxConfig
|
from ..models.blenderbot import BlenderbotOnnxConfig
|
||||||
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
|
from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
|
||||||
from ..models.camembert import CamembertOnnxConfig
|
from ..models.camembert import CamembertOnnxConfig
|
||||||
|
from ..models.data2vec import Data2VecTextOnnxConfig
|
||||||
from ..models.distilbert import DistilBertOnnxConfig
|
from ..models.distilbert import DistilBertOnnxConfig
|
||||||
from ..models.electra import ElectraOnnxConfig
|
from ..models.electra import ElectraOnnxConfig
|
||||||
from ..models.flaubert import FlaubertOnnxConfig
|
from ..models.flaubert import FlaubertOnnxConfig
|
||||||
|
@ -120,7 +121,7 @@ class FeaturesManager:
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
# "multiple-choice",
|
"multiple-choice",
|
||||||
"token-classification",
|
"token-classification",
|
||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=AlbertOnnxConfig,
|
onnx_config_cls=AlbertOnnxConfig,
|
||||||
|
@ -152,7 +153,7 @@ class FeaturesManager:
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
"causal-lm",
|
"causal-lm",
|
||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
# "multiple-choice",
|
"multiple-choice",
|
||||||
"token-classification",
|
"token-classification",
|
||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=BertOnnxConfig,
|
onnx_config_cls=BertOnnxConfig,
|
||||||
|
@ -162,6 +163,7 @@ class FeaturesManager:
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
"causal-lm",
|
"causal-lm",
|
||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
|
"multiple-choice",
|
||||||
"token-classification",
|
"token-classification",
|
||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=BigBirdOnnxConfig,
|
onnx_config_cls=BigBirdOnnxConfig,
|
||||||
|
@ -170,7 +172,7 @@ class FeaturesManager:
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
# "multiple-choice",
|
"multiple-choice",
|
||||||
"token-classification",
|
"token-classification",
|
||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=IBertOnnxConfig,
|
onnx_config_cls=IBertOnnxConfig,
|
||||||
|
@ -180,7 +182,7 @@ class FeaturesManager:
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
"causal-lm",
|
"causal-lm",
|
||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
# "multiple-choice",
|
"multiple-choice",
|
||||||
"token-classification",
|
"token-classification",
|
||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=CamembertOnnxConfig,
|
onnx_config_cls=CamembertOnnxConfig,
|
||||||
|
@ -189,7 +191,7 @@ class FeaturesManager:
|
||||||
"default",
|
"default",
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
# "multiple-choice",
|
"multiple-choice",
|
||||||
"token-classification",
|
"token-classification",
|
||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=DistilBertOnnxConfig,
|
onnx_config_cls=DistilBertOnnxConfig,
|
||||||
|
@ -199,6 +201,7 @@ class FeaturesManager:
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
"causal-lm",
|
"causal-lm",
|
||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
|
"multiple-choice",
|
||||||
"token-classification",
|
"token-classification",
|
||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=FlaubertOnnxConfig,
|
onnx_config_cls=FlaubertOnnxConfig,
|
||||||
|
@ -220,7 +223,7 @@ class FeaturesManager:
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
"causal-lm",
|
"causal-lm",
|
||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
# "multiple-choice",
|
"multiple-choice",
|
||||||
"token-classification",
|
"token-classification",
|
||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=RobertaOnnxConfig,
|
onnx_config_cls=RobertaOnnxConfig,
|
||||||
|
@ -233,7 +236,7 @@ class FeaturesManager:
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
"causal-lm",
|
"causal-lm",
|
||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
# "multiple-choice",
|
"multiple-choice",
|
||||||
"token-classification",
|
"token-classification",
|
||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=XLMRobertaOnnxConfig,
|
onnx_config_cls=XLMRobertaOnnxConfig,
|
||||||
|
@ -276,6 +279,7 @@ class FeaturesManager:
|
||||||
"masked-lm",
|
"masked-lm",
|
||||||
"causal-lm",
|
"causal-lm",
|
||||||
"sequence-classification",
|
"sequence-classification",
|
||||||
|
"multiple-choice",
|
||||||
"token-classification",
|
"token-classification",
|
||||||
"question-answering",
|
"question-answering",
|
||||||
onnx_config_cls=ElectraOnnxConfig,
|
onnx_config_cls=ElectraOnnxConfig,
|
||||||
|
@ -300,6 +304,15 @@ class FeaturesManager:
|
||||||
"seq2seq-lm-with-past",
|
"seq2seq-lm-with-past",
|
||||||
onnx_config_cls=BlenderbotSmallOnnxConfig,
|
onnx_config_cls=BlenderbotSmallOnnxConfig,
|
||||||
),
|
),
|
||||||
|
"data2vec-text": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"sequence-classification",
|
||||||
|
"multiple-choice",
|
||||||
|
"token-classification",
|
||||||
|
"question-answering",
|
||||||
|
onnx_config_cls=Data2VecTextOnnxConfig,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
|
AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
|
||||||
|
|
|
@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = {
|
||||||
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
("layoutlm", "microsoft/layoutlm-base-uncased"),
|
||||||
("vit", "google/vit-base-patch16-224"),
|
("vit", "google/vit-base-patch16-224"),
|
||||||
("beit", "microsoft/beit-base-patch16-224"),
|
("beit", "microsoft/beit-base-patch16-224"),
|
||||||
|
("data2vec-text", "facebook/data2vec-text-base"),
|
||||||
}
|
}
|
||||||
|
|
||||||
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
PYTORCH_EXPORT_WITH_PAST_MODELS = {
|
||||||
|
|
Loading…
Reference in New Issue