From 0b524b084857d0bf54eb613304a61bcdbd71e6fb Mon Sep 17 00:00:00 2001
From: thomwolf
Date: Mon, 5 Aug 2019 19:08:19 +0200
Subject: [PATCH] remove derived classes for now

---
 docs/source/model_doc/auto.rst        |  21 +-
 pytorch_transformers/__init__.py      |   2 +-
 pytorch_transformers/modeling_auto.py | 266 ------------------
 .../tests/modeling_auto_test.py       |  10 +-
 4 files changed, 4 insertions(+), 295 deletions(-)

diff --git a/docs/source/model_doc/auto.rst b/docs/source/model_doc/auto.rst
index 43f6e103bd..7b56eabafe 100644
--- a/docs/source/model_doc/auto.rst
+++ b/docs/source/model_doc/auto.rst
@@ -3,12 +3,9 @@ AutoModels
 
 In many cases, the architecture you want to use can be guessed from the name or the path of the pretrained model you are supplying to the ``from_pretrained`` method.
 
-AutoClasses are here to do this job for you so that you automatically retreive the relevant model given the name/path to the pretrained weights/config/vocabulary.
+AutoClasses are here to do this job for you so that you automatically retrieve the relevant model given the name/path to the pretrained weights/config/vocabulary:
 
-There are two types of AutoClasses:
-
-- ``AutoModel``, ``AutoConfig`` and ``AutoTokenizer``: instantiating these ones will directly create a class of the relevant architecture (ex: ``model = AutoModel.from_pretrained('bert-base-cased')`` will create a instance of ``BertModel``)
-- All the others (``AutoModelWithLMHead``, ``AutoModelForSequenceClassification``...) are standardized Auto classes for finetuning. Instantiating these will create instance of the same class (``AutoModelWithLMHead``, ``AutoModelForSequenceClassification``...) comprising (i) the relevant base model class (as mentioned just above) and (ii) a standard fine-tuning head on top, convenient for the task.
+Instantiating one of ``AutoModel``, ``AutoConfig`` and ``AutoTokenizer`` will directly create an instance of the relevant architecture class (e.g. ``model = AutoModel.from_pretrained('bert-base-cased')`` will create an instance of ``BertModel``).
 
 
 ``AutoConfig``
@@ -25,20 +22,6 @@
     :members:
 
 
-``AutoModelWithLMHead``
-~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autoclass:: pytorch_transformers.AutoModelWithLMHead
-    :members:
-
-
-``AutoModelForSequenceClassification``
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-
-.. autoclass:: pytorch_transformers.AutoModelForSequenceClassification
-    :members:
-
-
 ``AutoTokenizer``
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py
index 110c3dc3c7..04e5c3c9dd 100644
--- a/pytorch_transformers/__init__.py
+++ b/pytorch_transformers/__init__.py
@@ -8,7 +8,7 @@
 from .tokenization_xlnet import XLNetTokenizer, SPIECE_UNDERLINE
 from .tokenization_xlm import XLMTokenizer
 from .tokenization_utils import (PreTrainedTokenizer)
-from .modeling_auto import (AutoConfig, AutoModel, AutoModelForSequenceClassification, AutoModelWithLMHead)
+from .modeling_auto import (AutoConfig, AutoModel)
 from .modeling_bert import (BertConfig, BertModel, BertForPreTraining,
                             BertForMaskedLM, BertForNextSentencePrediction,
diff --git a/pytorch_transformers/modeling_auto.py b/pytorch_transformers/modeling_auto.py
index 22a35090aa..64b151e3a3 100644
--- a/pytorch_transformers/modeling_auto.py
+++ b/pytorch_transformers/modeling_auto.py
@@ -234,269 +234,3 @@ class AutoModel(object):
         raise ValueError("Unrecognized model identifier in {}. Should contains one of "
                          "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
                          "'xlm'".format(pretrained_model_name_or_path))
-
-
-class DerivedAutoModel(PreTrainedModel):
-    r"""
-        :class:`~pytorch_transformers.DerivedAutoModel` is a base class for building
-        standardized derived models on top of :class:`~pytorch_transformers.AutoModel` by adding heads
-
-        The `from_pretrained()` method take care of using the correct base model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
-
-        The base model class to instantiate is selected as the first pattern matching
-        in the `pretrained_model_name_or_path` string (in the following order):
-            - contains `bert`: BertConfig (Bert model)
-            - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
-            - contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
-            - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
-            - contains `xlnet`: XLNetConfig (XLNet model)
-            - contains `xlm`: XLMConfig (XLM model)
-
-        This class should usually not be instantiated using `__init__()` but `from_pretrained()`.
-    """
-    config_class = None
-    pretrained_model_archive_map = {}
-    load_tf_weights = lambda model, config, path: None
-    base_model_prefix = "transformer"
-
-    def __init__(self, base_model):
-        super(DerivedAutoModel, self).__init__(base_model.config)
-        self.transformer = base_model
-
-    def init_weights(self, module):
-        """ Initialize the weights. Use the base model initialization function.
-        """
-        self.transformer.init_weights(module)
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
-        r""" Instantiate a :class:`~pytorch_transformers.DerivedAutoModel` with one of the base model classes of the library
-        from a pre-trained model configuration.
-
-        The base model class to instantiate is selected as the first pattern matching
-        in the `pretrained_model_name_or_path` string (in the following order):
-            - contains `bert`: BertConfig (Bert model)
-            - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
-            - contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
-            - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
-            - contains `xlnet`: XLNetConfig (XLNet model)
-            - contains `xlm`: XLMConfig (XLM model)
-
-        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
-        To train the model, you should first set it back in training mode with `model.train()`
-
-        Params:
-            **pretrained_model_name_or_path**: either:
-                - a string with the `shortcut name` of a pre-trained model to load from cache
-                    or download and cache if not already stored in cache (e.g. 'bert-base-uncased').
-                - a path to a `directory` containing a configuration file saved
-                    using the `save_pretrained(save_directory)` method.
-                - a path or url to a tensorflow index checkpoint `file` (e.g. `./tf_model/model.ckpt.index`).
-                    In this case, ``from_tf`` should be set to True and a configuration object should be
-                    provided as `config` argument. This loading option is slower than converting the TensorFlow
-                    checkpoint in a PyTorch model using the provided conversion scripts and loading
-                    the PyTorch model afterwards.
-            **model_args**: (`optional`) Sequence:
-                All remaning positional arguments will be passed to the underlying model's __init__ function
-            **config**: an optional configuration for the model to use instead of an automatically loaded configuation.
-                Configuration can be automatically loaded when:
-                - the model is a model provided by the library (loaded with a `shortcut name` of a pre-trained model), or
-                - the model was saved using the `save_pretrained(save_directory)` (loaded by suppling the save directory).
-            **state_dict**: an optional state dictionnary for the model to use instead of a state dictionary loaded
-                from saved weights file.
-                This option can be used if you want to create a model from a pretrained configuration but load your own weights.
-                In this case though, you should check if using `save_pretrained(dir)` and `from_pretrained(save_directory)` is not
-                a simpler option.
-            **cache_dir**: (`optional`) string:
-                Path to a directory in which a downloaded pre-trained model
-                configuration should be cached if the standard cache should not be used.
-            **output_loading_info**: (`optional`) boolean:
-                Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
-            **kwargs**: (`optional`) dict:
-                Dictionary of key, values to update the configuration object after loading.
-                Can be used to override selected configuration parameters. E.g. ``output_attention=True``.
-
-                - If a configuration is provided with `config`, **kwargs will be directly passed
-                    to the underlying model's __init__ method.
-                - If a configuration is not provided, **kwargs will be first passed to the pretrained
-                    model configuration class loading function (`PretrainedConfig.from_pretrained`).
-                    Each key of **kwargs that corresponds to a configuration attribute
-                    will be used to override said attribute with the supplied **kwargs value.
-                    Remaining keys that do not correspond to any configuration attribute will
-                    be passed to the underlying model's __init__ function.
-
-        Examples::
-
-            model = AutoModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
-            model = AutoModel.from_pretrained('./test/bert_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
-            model = AutoModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
-            assert model.config.output_attention == True
-            # Loading from a TF checkpoint file instead of a PyTorch model (slower)
-            config = AutoConfig.from_json_file('./tf_model/bert_tf_model_config.json')
-            model = AutoModel.from_pretrained('./tf_model/bert_tf_checkpoint.ckpt.index', from_tf=True, config=config)
-
-        """
-        if 'bert' in pretrained_model_name_or_path:
-            base_model_class = BertModel
-        elif 'openai-gpt' in pretrained_model_name_or_path:
-            base_model_class = OpenAIGPTModel
-        elif 'gpt2' in pretrained_model_name_or_path:
-            base_model_class = GPT2Model
-        elif 'transfo-xl' in pretrained_model_name_or_path:
-            base_model_class = TransfoXLModel
-        elif 'xlnet' in pretrained_model_name_or_path:
-            base_model_class = XLNetModel
-        elif 'xlm' in pretrained_model_name_or_path:
-            base_model_class = XLMModel
-        else:
-            raise ValueError("Unrecognized model identifier in {}. Should contains one of "
-                             "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', "
-                             "'xlm'".format(pretrained_model_name_or_path))
-
-        # Get a pretrained base_model
-        base_model = base_model_class.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
-
-        # Create our derived model
-        model = cls(base_model)
-
-        # Setup class attribute from the base model class
-        model.config_class = base_model.config_class
-        model.pretrained_model_archive_map = base_model.pretrained_model_archive_map
-        model.load_tf_weights = base_model.load_tf_weights
-
-        return model
-
-
-class AutoModelWithLMHead(DerivedAutoModel):
-    r"""
-        :class:`~pytorch_transformers.AutoModelWithLMHead` is a base class for language modeling
-        that contains
-
-            - a base model instantiated as one of the base model classes of the library when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class method, and .
-            - a language modeling head on top of the base model.
-
-        The `from_pretrained()` method take care of using the correct base model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
-
-        The base model class to instantiate is selected as the first pattern matching
-        in the `pretrained_model_name_or_path` string (in the following order):
-            - contains `bert`: BertConfig (Bert model)
-            - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
-            - contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
-            - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
-            - contains `xlnet`: XLNetConfig (XLNet model)
-            - contains `xlm`: XLMConfig (XLM model)
-
-        This class should usually not be instantiated using `__init__()` but `from_pretrained()`.
-    """
-
-    def __init__(self, base_model):
-        super(AutoModelWithLMHead, self).__init__(base_model)
-        config = base_model.config
-
-        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
-
-        self.apply(self.init_weights)
-        self.tie_weights()
-
-    def tie_weights(self):
-        """ Make sure we are sharing the input and output embeddings.
-            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
-        """
-        # get input embeddings - whatever the model is
-        input_embeddings = self.transformer.resize_token_embeddings(new_num_tokens=None)
-
-        # tie of clone (torchscript) embeddings
-        self._tie_or_clone_weights(self.lm_head, input_embeddings)
-
-    def forward(self, input_ids, **kwargs):
-        labels = kwargs.pop('labels', None)  # Python 2 compatibility...
-
-        transformer_outputs = self.transformer(input_ids, **kwargs)
-        hidden_states = transformer_outputs[0]
-
-        lm_logits = self.lm_head(hidden_states)
-
-        outputs = (lm_logits,) + transformer_outputs[1:]
-        if labels is not None:
-            loss_fct = CrossEntropyLoss(ignore_index=-1)
-            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)),
-                            labels.view(-1))
-            outputs = (loss,) + outputs
-
-        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
-
-
-AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS = {
-    'num_labels': 2,
-    'summary_type': 'first',
-    'summary_use_proj': True,
-    'summary_activation': None,
-    'summary_proj_to_labels': True,
-    'summary_first_dropout': 0.1
-}
-
-
-class AutoModelForSequenceClassification(DerivedAutoModel):
-    r"""
-        :class:`~pytorch_transformers.AutoModelForSequenceClassification` is a class for sequence classification
-        that contains
-
-            - a base model instantiated as one of the base model classes of the library when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)` class method, and .
-            - a classification head on top of the base model.
-
-        The `from_pretrained()` method take care of using the correct base model class instance
-        using pattern matching on the `pretrained_model_name_or_path` string.
-
-        The base model class to instantiate is selected as the first pattern matching
-        in the `pretrained_model_name_or_path` string (in the following order):
-            - contains `bert`: BertConfig (Bert model)
-            - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)
-            - contains `gpt2`: GPT2Config (OpenAI GPT-2 model)
-            - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model)
-            - contains `xlnet`: XLNetConfig (XLNet model)
-            - contains `xlm`: XLMConfig (XLM model)
-
-        This class should usually not be instantiated using `__init__()` but `from_pretrained()`.
-    """
-
-    def __init__(self, base_model):
-        super(AutoModelForSequenceClassification, self).__init__(base_model)
-        # Complete configuration with defaults if necessary
-        config = base_model.config
-        for key, value in AUTO_MODEL_SEQUENCE_SUMMARY_DEFAULTS.items():
-            if not hasattr(config, key):
-                setattr(config, key, value)
-
-        # Update base model and derived model config
-        self.transformer.config = config
-        self.config = config
-
-        self.num_labels = config.num_labels
-        self.sequence_summary = SequenceSummary(config)
-
-        self.apply(self.init_weights)
-
-    def forward(self, input_ids, cls_index, **kwargs):
-        labels = kwargs.pop('labels', None)  # Python 2 compatibility...
- - transformer_outputs = self.transformer(input_ids, **kwargs) - - output = transformer_outputs[0] - logits = self.sequence_summary(output, cls_index=cls_index) - - outputs = (logits,) + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here - - if labels is not None: - if self.num_labels == 1: - # We are doing regression - loss_fct = MSELoss() - loss = loss_fct(logits.view(-1), labels.view(-1)) - else: - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - outputs = (loss,) + outputs - - return outputs diff --git a/pytorch_transformers/tests/modeling_auto_test.py b/pytorch_transformers/tests/modeling_auto_test.py index 07042a255c..d0c830abc7 100644 --- a/pytorch_transformers/tests/modeling_auto_test.py +++ b/pytorch_transformers/tests/modeling_auto_test.py @@ -21,7 +21,7 @@ import shutil import pytest import logging -from pytorch_transformers import AutoConfig, BertConfig, AutoModel, BertModel, AutoModelForSequenceClassification, AutoModelWithLMHead +from pytorch_transformers import AutoConfig, BertConfig, AutoModel, BertModel from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) @@ -42,14 +42,6 @@ class AutoModelTest(unittest.TestCase): for value in loading_info.values(): self.assertEqual(len(value), 0) - model = AutoModelForSequenceClassification.from_pretrained(model_name) - self.assertIsNotNone(model) - self.assertIsInstance(getattr(model, model.base_model_prefix), BertModel) - - model = AutoModelWithLMHead.from_pretrained(model_name) - self.assertIsNotNone(model) - self.assertIsInstance(getattr(model, model.base_model_prefix), BertModel) - if __name__ == "__main__": unittest.main()
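---
For reference, a minimal usage sketch of the auto-class surface this patch keeps: only the base ``AutoConfig``/``AutoModel`` (and ``AutoTokenizer``) entry points remain importable, and each resolves the concrete class by pattern matching on the model name, as documented in ``AutoModel.from_pretrained`` above. The shortcut name 'bert-base-uncased' is illustrative and downloads weights on first use:

    from pytorch_transformers import AutoConfig, AutoModel, BertConfig, BertModel

    # 'bert' in the identifier selects the BERT classes
    config = AutoConfig.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained('bert-base-uncased')

    assert isinstance(config, BertConfig)
    assert isinstance(model, BertModel)

Code that relied on the removed ``AutoModelWithLMHead`` / ``AutoModelForSequenceClassification`` wrappers can instantiate the model-specific head classes directly instead (e.g. ``BertForMaskedLM``, visible in the ``__init__.py`` imports above).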