Add onnx export of models with a multiple choice classification head (#16758)

* Add export of models with a multiple-choice classification head
Ella Charlaix 2022-04-19 15:51:51 +02:00 committed by GitHub
parent b74a955325
commit 77de8d6c31
19 changed files with 134 additions and 41 deletions
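This commit enables the "multiple-choice" ONNX export feature for a number of encoder-only model types and registers data2vec-text as a new exportable model type (see the features.py hunks below). A rough sketch of how the feature could be driven from the Python export API of this transformers version; the function names and signatures are recalled from the library's documentation of that era rather than taken from this diff, so treat them as assumptions:

from pathlib import Path

from transformers import AutoModelForMultipleChoice, AutoTokenizer
from transformers.onnx import FeaturesManager, export

checkpoint = "bert-base-uncased"  # illustrative: any checkpoint whose model type now lists "multiple-choice"
model = AutoModelForMultipleChoice.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Resolve the OnnxConfig class registered for this model type and feature
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model, feature="multiple-choice")
onnx_config = model_onnx_config(model.config)

# Export; dummy inputs are generated with the new (batch, choice, sequence) layout
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("model.onnx"))

The same export should be reachable from the CLI through the --feature=multiple-choice flag of python -m transformers.onnx (assumed, not shown in this diff).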

View File

@@ -159,10 +159,14 @@ class AlbertConfig(PretrainedConfig):
 class AlbertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
-                ("token_type_ids", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
             ]
         )
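Each of the per-model OnnxConfig hunks in this commit makes the same change: for the multiple-choice task the exported inputs gain an extra "choice" axis. A minimal check of the new behaviour, assuming AlbertOnnxConfig can be imported and instantiated as shown (the constructor arguments come from the base OnnxConfig class):

from transformers import AlbertConfig
from transformers.models.albert.configuration_albert import AlbertOnnxConfig

# Default task: rank-2 dynamic axes (batch, sequence)
assert AlbertOnnxConfig(AlbertConfig()).inputs["input_ids"] == {0: "batch", 1: "sequence"}

# Multiple-choice task: rank-3 dynamic axes (batch, choice, sequence)
mc_config = AlbertOnnxConfig(AlbertConfig(), task="multiple-choice")
assert mc_config.inputs["input_ids"] == {0: "batch", 1: "choice", 2: "sequence"}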

View File

@@ -160,10 +160,14 @@ class BertConfig(PretrainedConfig):
 class BertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
-                ("token_type_ids", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
             ]
         )

View File

@@ -168,9 +168,13 @@ class BigBirdConfig(PretrainedConfig):
 class BigBirdOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )

View File

@@ -44,9 +44,13 @@ class CamembertConfig(RobertaConfig):
 class CamembertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )

View File

@@ -139,9 +139,13 @@ class Data2VecTextConfig(PretrainedConfig):
 class Data2VecTextOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )

View File

@@ -134,9 +134,13 @@ class DistilBertConfig(PretrainedConfig):
 class DistilBertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )

View File

@@ -179,10 +179,14 @@ class ElectraConfig(PretrainedConfig):
 class ElectraOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
-                ("token_type_ids", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
+                ("token_type_ids", dynamic_axis),
             ]
         )

View File

@@ -146,9 +146,13 @@ class FlaubertConfig(XLMConfig):
 class FlaubertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )

View File

@@ -234,7 +234,7 @@ class GPT2OnnxConfig(OnnxConfigWithPast):
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
         common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, seq_length, is_pair, framework
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
         )

         # We need to order the input in the way they appears in the forward()
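The keyword-argument changes in this and the following GPT-Neo, GPT-J and LayoutLM hunks are a direct consequence of the new num_choices parameter added to OnnxConfig.generate_dummy_inputs further down (between seq_length and is_pair): subclasses that previously forwarded their arguments positionally would otherwise pass is_pair into num_choices, so the call sites switch to explicit keywords.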

View File

@@ -233,7 +233,7 @@ class GPTNeoOnnxConfig(OnnxConfigWithPast):
     ) -> Mapping[str, Any]:
         common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, seq_length, is_pair, framework
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
         )

         # We need to order the input in the way they appears in the forward()

View File

@@ -183,7 +183,7 @@ class GPTJOnnxConfig(OnnxConfigWithPast):
         framework: Optional[TensorType] = None,
     ) -> Mapping[str, Any]:
         common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, seq_length, is_pair, framework
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
         )

         # We need to order the input in the way they appears in the forward()

View File

@@ -131,9 +131,13 @@ class IBertConfig(PretrainedConfig):
 class IBertOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )

View File

@@ -171,7 +171,9 @@ class LayoutLMOnnxConfig(OnnxConfig):
             Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
         """
-        input_dict = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
+        input_dict = super().generate_dummy_inputs(
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+        )

         # Generate a dummy bbox
         box = [48, 84, 73, 128]

View File

@@ -70,9 +70,13 @@ class RobertaConfig(BertConfig):
 class RobertaOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )

View File

@@ -47,9 +47,13 @@ class XLMRobertaConfig(RobertaConfig):
 class XLMRobertaOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )

View File

@@ -143,9 +143,13 @@ class XLMRobertaXLConfig(PretrainedConfig):
 class XLMRobertaXLOnnxConfig(OnnxConfig):
     @property
     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
         return OrderedDict(
             [
-                ("input_ids", {0: "batch", 1: "sequence"}),
-                ("attention_mask", {0: "batch", 1: "sequence"}),
+                ("input_ids", dynamic_axis),
+                ("attention_mask", dynamic_axis),
             ]
         )

View File

@@ -71,6 +71,7 @@ class OnnxConfig(ABC):
     default_fixed_batch = 2
     default_fixed_sequence = 8
+    default_fixed_num_choices = 4
     torch_onnx_minimum_version = version.parse("1.8")
     _tasks_to_common_outputs = {
         "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
@@ -174,6 +175,16 @@
         """
         return OnnxConfig.default_fixed_sequence

+    @property
+    def default_num_choices(self) -> int:
+        """
+        The default number of choices to use if no other indication
+
+        Returns:
+            Integer > 0
+        """
+        return OnnxConfig.default_fixed_num_choices
+
     @property
     def default_onnx_opset(self) -> int:
         """
@@ -240,6 +251,7 @@
         preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
         batch_size: int = -1,
         seq_length: int = -1,
+        num_choices: int = -1,
         is_pair: bool = False,
         framework: Optional[TensorType] = None,
         num_channels: int = 3,
@@ -255,6 +267,8 @@
                 The preprocessor associated with this model configuration.
             batch_size (`int`, *optional*, defaults to -1):
                 The batch size to export the model for (-1 means dynamic axis).
+            num_choices (`int`, *optional*, defaults to -1):
+                The number of candidate answers provided for multiple choice task (-1 means dynamic axis).
             seq_length (`int`, *optional*, defaults to -1):
                 The sequence length to export the model for (-1 means dynamic axis).
             is_pair (`bool`, *optional*, defaults to `False`):
@@ -295,6 +309,19 @@
             )
             # Generate dummy inputs according to compute batch and sequence
             dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
+            if self.task == "multiple-choice":
+                # If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid optimizations
+                # made by ONNX
+                num_choices = compute_effective_axis_dimension(
+                    num_choices, fixed_dimension=OnnxConfig.default_fixed_num_choices, num_token_to_add=0
+                )
+                dummy_input = dummy_input * num_choices
+                # The shape of the tokenized inputs values is [batch_size * num_choices, seq_length]
+                tokenized_input = preprocessor(dummy_input, text_pair=dummy_input)
+                # Unflatten the tokenized inputs values expanding it to the shape [batch_size, num_choices, seq_length]
+                for k, v in tokenized_input.items():
+                    tokenized_input[k] = [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
+                return dict(tokenized_input.convert_to_tensors(tensor_type=framework))
             return dict(preprocessor(dummy_input, return_tensors=framework))
         elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
             # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
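The unflattening loop added above turns the tokenizer's flat output (one row per example/choice pair) into nested lists before tensor conversion. A plain-Python sketch of the same bookkeeping, using only the fixed defaults from this file:

batch_size, num_choices = 2, 4  # default_fixed_batch and default_fixed_num_choices
flat = [[101, 100, 102]] * (batch_size * num_choices)  # tokenizer output: batch_size * num_choices rows
nested = [flat[i : i + num_choices] for i in range(0, len(flat), num_choices)]
assert len(nested) == batch_size and all(len(group) == num_choices for group in nested)
# After convert_to_tensors, every input therefore has shape (batch_size, num_choices, seq_length)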
@@ -408,7 +435,9 @@ class OnnxConfigWithPast(OnnxConfig, ABC):
     ) -> Mapping[str, Any]:
         # TODO: should we set seq_length = 1 when self.use_past = True?
-        common_inputs = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
+        common_inputs = super().generate_dummy_inputs(
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
+        )

         if self.use_past:
             if not is_torch_available():
@@ -527,13 +556,13 @@ class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
     ) -> Mapping[str, Any]:
         encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, seq_length, is_pair, framework
+            tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
         )

         # Generate decoder inputs
         decoder_seq_length = seq_length if not self.use_past else 1
         decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
-            tokenizer, batch_size, decoder_seq_length, is_pair, framework
+            tokenizer, batch_size=batch_size, seq_length=decoder_seq_length, is_pair=is_pair, framework=framework
         )
         decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
         common_inputs = dict(**encoder_inputs, **decoder_inputs)
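A hedged end-to-end check of the new dummy-input path (the checkpoint name and the PyTorch requirement are illustrative assumptions, not part of the diff):

from transformers import AutoTokenizer, BertConfig, TensorType
from transformers.models.bert.configuration_bert import BertOnnxConfig

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
onnx_config = BertOnnxConfig(BertConfig(), task="multiple-choice")

# With all sizes left dynamic (-1), the fixed defaults apply: batch=2, choices=4
dummy = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
for name, tensor in dummy.items():
    # Every tensor should now be rank 3: (batch, choice, sequence)
    assert tensor.shape[:2] == (2, 4), (name, tensor.shape)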

View File

@@ -10,6 +10,7 @@ from ..models.big_bird import BigBirdOnnxConfig
 from ..models.blenderbot import BlenderbotOnnxConfig
 from ..models.blenderbot_small import BlenderbotSmallOnnxConfig
 from ..models.camembert import CamembertOnnxConfig
+from ..models.data2vec import Data2VecTextOnnxConfig
 from ..models.distilbert import DistilBertOnnxConfig
 from ..models.electra import ElectraOnnxConfig
 from ..models.flaubert import FlaubertOnnxConfig
@@ -120,7 +121,7 @@ class FeaturesManager:
             "default",
             "masked-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=AlbertOnnxConfig,
@@ -152,7 +153,7 @@
             "masked-lm",
             "causal-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=BertOnnxConfig,
@@ -162,6 +163,7 @@
             "masked-lm",
             "causal-lm",
             "sequence-classification",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=BigBirdOnnxConfig,
@@ -170,7 +172,7 @@
             "default",
             "masked-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=IBertOnnxConfig,
@@ -180,7 +182,7 @@
             "masked-lm",
             "causal-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=CamembertOnnxConfig,
@@ -189,7 +191,7 @@
             "default",
             "masked-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=DistilBertOnnxConfig,
@@ -199,6 +201,7 @@
             "masked-lm",
             "causal-lm",
             "sequence-classification",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=FlaubertOnnxConfig,
@@ -220,7 +223,7 @@
             "masked-lm",
             "causal-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=RobertaOnnxConfig,
@@ -233,7 +236,7 @@
             "masked-lm",
             "causal-lm",
             "sequence-classification",
-            # "multiple-choice",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=XLMRobertaOnnxConfig,
@@ -276,6 +279,7 @@
             "masked-lm",
             "causal-lm",
             "sequence-classification",
+            "multiple-choice",
             "token-classification",
             "question-answering",
             onnx_config_cls=ElectraOnnxConfig,
@@ -300,6 +304,15 @@
             "seq2seq-lm-with-past",
             onnx_config_cls=BlenderbotSmallOnnxConfig,
         ),
+        "data2vec-text": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            "multiple-choice",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls=Data2VecTextOnnxConfig,
+        ),
     }

     AVAILABLE_FEATURES = sorted(reduce(lambda s1, s2: s1 | s2, (v.keys() for v in _SUPPORTED_MODEL_TYPE.values())))
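One way to confirm the newly registered feature programmatically, assuming FeaturesManager.get_supported_features_for_model_type keeps the signature it has in this version:

from transformers.onnx.features import FeaturesManager

features = FeaturesManager.get_supported_features_for_model_type("data2vec-text")
assert "multiple-choice" in features  # maps the feature name to its OnnxConfig constructor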

View File

@@ -182,6 +182,7 @@ PYTORCH_EXPORT_MODELS = {
     ("layoutlm", "microsoft/layoutlm-base-uncased"),
     ("vit", "google/vit-base-patch16-224"),
     ("beit", "microsoft/beit-base-patch16-224"),
+    ("data2vec-text", "facebook/data2vec-text-base"),
}

PYTORCH_EXPORT_WITH_PAST_MODELS = {
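Adding the data2vec-text checkpoint to PYTORCH_EXPORT_MODELS means the new model type is exercised by the parametrized PyTorch export tests alongside the existing entries; these export tests are slow tests in this suite, so they presumably only run when the slow-test flag is enabled.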