diff --git a/pytorch_transformers/__init__.py b/pytorch_transformers/__init__.py index 3e8719bd8d..04a73c3abc 100644 --- a/pytorch_transformers/__init__.py +++ b/pytorch_transformers/__init__.py @@ -1,4 +1,7 @@ __version__ = "1.2.0" + +# Tokenizer +from .tokenization_utils import (PreTrainedTokenizer) from .tokenization_auto import AutoTokenizer from .tokenization_bert import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .tokenization_openai import OpenAIGPTTokenizer @@ -9,46 +12,51 @@ from .tokenization_xlm import XLMTokenizer from .tokenization_roberta import RobertaTokenizer from .tokenization_distilbert import DistilBertTokenizer -from .tokenization_utils import (PreTrainedTokenizer) +# Configurations +from .configuration_utils import CONFIG_NAME, PretrainedConfig +from .configuration_auto import AutoConfig +from .configuration_bert import BertConfig, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP +from .configuration_openai import OpenAIGPTConfig, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP +from .configuration_transfo_xl import TransfoXLConfig, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP +from .configuration_gpt2 import GPT2Config, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP +from .configuration_xlnet import XLNetConfig, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP +from .configuration_xlm import XLMConfig, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP +from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP +from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP -from .modeling_auto import (AutoConfig, AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, +# Modeling +from .modeling_utils import (WEIGHTS_NAME, TF_WEIGHTS_NAME, PreTrainedModel, prune_layer, Conv1D) +from .modeling_auto import (AutoModel, AutoModelForSequenceClassification, AutoModelForQuestionAnswering, AutoModelWithLMHead) -from .modeling_bert import (BertConfig, BertPreTrainedModel, BertModel, BertForPreTraining, +from .modeling_bert import (BertPreTrainedModel, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, BertForSequenceClassification, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering, - load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, - BERT_PRETRAINED_CONFIG_ARCHIVE_MAP) -from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTPreTrainedModel, OpenAIGPTModel, + load_tf_weights_in_bert, BERT_PRETRAINED_MODEL_ARCHIVE_MAP) +from .modeling_openai import (OpenAIGPTPreTrainedModel, OpenAIGPTModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, - load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, - OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) -from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel, - load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, - TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) -from .modeling_gpt2 import (GPT2Config, GPT2PreTrainedModel, GPT2Model, + load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) +from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel, + load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) +from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel, - load_tf_weights_in_gpt2, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, - GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) -from .modeling_xlnet import (XLNetConfig, - XLNetPreTrainedModel, XLNetModel, 
XLNetLMHeadModel, + load_tf_weights_in_gpt2, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) +from .modeling_xlnet import (XLNetPreTrainedModel, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering, - load_tf_weights_in_xlnet, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, - XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) -from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel, + load_tf_weights_in_xlnet, XLNET_PRETRAINED_MODEL_ARCHIVE_MAP) +from .modeling_xlm import (XLMPreTrainedModel , XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, - XLMForQuestionAnswering, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, - XLM_PRETRAINED_MODEL_ARCHIVE_MAP) -from .modeling_roberta import (RobertaConfig, RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification, - ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) -from .modeling_distilbert import (DistilBertConfig, DistilBertForMaskedLM, DistilBertModel, + XLMForQuestionAnswering, XLM_PRETRAINED_MODEL_ARCHIVE_MAP) +from .modeling_roberta import (RobertaForMaskedLM, RobertaModel, RobertaForSequenceClassification, + ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP) +from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel, DistilBertForSequenceClassification, DistilBertForQuestionAnswering, - DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) -from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME, - PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) + DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP) +# Optimization from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule) -from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path) +# Files and general utilities +from .file_utils import (PYTORCH_TRANSFORMERS_CACHE, PYTORCH_PRETRAINED_BERT_CACHE, cached_path, add_start_docstrings, add_end_docstrings) diff --git a/pytorch_transformers/configuration_auto.py b/pytorch_transformers/configuration_auto.py new file mode 100644 index 0000000000..9e35f85dc7 --- /dev/null +++ b/pytorch_transformers/configuration_auto.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2018 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Auto Model class. 
""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging + +from .configuration_bert import BertConfig +from .configuration_openai import OpenAIGPTConfig +from .configuration_gpt2 import GPT2Config +from .configuration_transfo_xl import TransfoXLConfig +from .configuration_xlnet import XLNetConfig +from .configuration_xlm import XLMConfig +from .configuration_roberta import RobertaConfig +from .configuration_distilbert import DistilBertConfig + +logger = logging.getLogger(__name__) + + +class AutoConfig(object): + r""":class:`~pytorch_transformers.AutoConfig` is a generic configuration class + that will be instantiated as one of the configuration classes of the library + when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` + class method. + + The `from_pretrained()` method take care of returning the correct model class instance + using pattern matching on the `pretrained_model_name_or_path` string. + + The base model class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertConfig (DistilBERT model) + - contains `bert`: BertConfig (Bert model) + - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) + - contains `gpt2`: GPT2Config (OpenAI GPT-2 model) + - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model) + - contains `xlnet`: XLNetConfig (XLNet model) + - contains `xlm`: XLMConfig (XLM model) + - contains `roberta`: RobertaConfig (RoBERTa model) + + This class cannot be instantiated using `__init__()` (throw an error). + """ + def __init__(self): + raise EnvironmentError("AutoConfig is designed to be instantiated " + "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method.") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" Instantiate a one of the configuration classes of the library + from a pre-trained model configuration. + + The configuration class to instantiate is selected as the first pattern matching + in the `pretrained_model_name_or_path` string (in the following order): + - contains `distilbert`: DistilBertConfig (DistilBERT model) + - contains `bert`: BertConfig (Bert model) + - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) + - contains `gpt2`: GPT2Config (OpenAI GPT-2 model) + - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model) + - contains `xlnet`: XLNetConfig (XLNet model) + - contains `xlm`: XLMConfig (XLM model) + - contains `roberta`: RobertaConfig (RoBERTa model) + + Params: + pretrained_model_name_or_path: either: + + - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. + - a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. + - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + + kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading. + + - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. 
+ - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + return_unused_kwargs: (`optional`) bool: + + - If False, then this function returns just the final configuration object. + - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored. + + Examples:: + + config = AutoConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. + config = AutoConfig.from_pretrained('./test/bert_saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')` + config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json') + config = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) + assert config.output_attention == True + config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True, + foo=False, return_unused_kwargs=True) + assert config.output_attention == True + assert unused_kwargs == {'foo': False} + + """ + if 'distilbert' in pretrained_model_name_or_path: + return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif 'roberta' in pretrained_model_name_or_path: + return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif 'bert' in pretrained_model_name_or_path: + return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif 'openai-gpt' in pretrained_model_name_or_path: + return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif 'gpt2' in pretrained_model_name_or_path: + return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif 'transfo-xl' in pretrained_model_name_or_path: + return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif 'xlnet' in pretrained_model_name_or_path: + return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + elif 'xlm' in pretrained_model_name_or_path: + return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + + raise ValueError("Unrecognized model identifier in {}. Should contains one of " + "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " + "'xlm', 'roberta'".format(pretrained_model_name_or_path)) diff --git a/pytorch_transformers/configuration_bert.py b/pytorch_transformers/configuration_bert.py new file mode 100644 index 0000000000..7fff3e5d05 --- /dev/null +++ b/pytorch_transformers/configuration_bert.py @@ -0,0 +1,113 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
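Note on the AutoConfig routing above: resolution is pure substring matching on the model name, and the order of the branches matters ('roberta' is tested before 'bert' because the string "roberta" itself contains "bert"). A minimal usage sketch, assuming network access to the S3 config files listed in the archive maps; the failing name is hypothetical::

    from pytorch_transformers import AutoConfig

    config = AutoConfig.from_pretrained('roberta-base')       # resolves to RobertaConfig
    config = AutoConfig.from_pretrained('bert-base-uncased')  # resolves to BertConfig

    # Names matching none of the known patterns raise a ValueError
    try:
        AutoConfig.from_pretrained('my-custom-model')          # hypothetical name
    except ValueError:
        pass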
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" BERT model configuration """ + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import logging +import sys +from io import open + +from .configuration_utils import PretrainedConfig + +logger = logging.getLogger(__name__) + +BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", + 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", + 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", + 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", + 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", + 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", + 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", + 'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", + 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", + 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", + 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", + 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", + 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", +} + + +class BertConfig(PretrainedConfig): + r""" + :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a + `BertModel`. + + + Arguments: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. + hidden_size: Size of the encoder layers and the pooler layer. + num_hidden_layers: Number of hidden layers in the Transformer encoder. + num_attention_heads: Number of attention heads for each attention layer in + the Transformer encoder. + intermediate_size: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + hidden_act: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + hidden_dropout_prob: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attention_probs_dropout_prob: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. 
Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size: The vocabulary size of the `token_type_ids` passed into + `BertModel`. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_eps: The epsilon used by LayerNorm. + """ + pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP + + def __init__(self, + vocab_size_or_config_json_file=30522, + hidden_size=768, + num_hidden_layers=12, + num_attention_heads=12, + intermediate_size=3072, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=2, + initializer_range=0.02, + layer_norm_eps=1e-12, + **kwargs): + super(BertConfig, self).__init__(**kwargs) + if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.hidden_act = hidden_act + self.intermediate_size = intermediate_size + self.hidden_dropout_prob = hidden_dropout_prob + self.attention_probs_dropout_prob = attention_probs_dropout_prob + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + else: + raise ValueError("First argument must be either a vocabulary size (int)" + " or the path to a pretrained model config file (str)") diff --git a/pytorch_transformers/configuration_distilbert.py b/pytorch_transformers/configuration_distilbert.py new file mode 100644 index 0000000000..b8929eedec --- /dev/null +++ b/pytorch_transformers/configuration_distilbert.py @@ -0,0 +1,89 @@ +# coding=utf-8 +# Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
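The BertConfig constructor above accepts either an integer vocabulary size or a path to a JSON configuration file as its first argument. A small sketch of both paths; the filename is hypothetical and the round trip relies on the to_json_file and __eq__ helpers inherited from PretrainedConfig::

    from pytorch_transformers import BertConfig

    # Keyword construction: the first argument is the vocabulary size
    config = BertConfig(vocab_size_or_config_json_file=30522, num_hidden_layers=6)
    assert config.num_hidden_layers == 6

    # The same first argument may instead be a path to a JSON config file
    config.to_json_file('bert_config.json')        # hypothetical filename
    reloaded = BertConfig('bert_config.json')
    assert reloaded == config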
+""" DistilBERT model configuration """ +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import sys +import json +import logging +from io import open + +from .configuration_utils import PretrainedConfig + +logger = logging.getLogger(__name__) + +DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", + 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json" +} + + +class DistilBertConfig(PretrainedConfig): + pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP + + def __init__(self, + vocab_size_or_config_json_file=30522, + max_position_embeddings=512, + sinusoidal_pos_embds=True, + n_layers=6, + n_heads=12, + dim=768, + hidden_dim=4*768, + dropout=0.1, + attention_dropout=0.1, + activation='gelu', + initializer_range=0.02, + tie_weights_=True, + qa_dropout=0.1, + seq_classif_dropout=0.2, + **kwargs): + super(DistilBertConfig, self).__init__(**kwargs) + + if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.max_position_embeddings = max_position_embeddings + self.sinusoidal_pos_embds = sinusoidal_pos_embds + self.n_layers = n_layers + self.n_heads = n_heads + self.dim = dim + self.hidden_dim = hidden_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.activation = activation + self.initializer_range = initializer_range + self.tie_weights_ = tie_weights_ + self.qa_dropout = qa_dropout + self.seq_classif_dropout = seq_classif_dropout + else: + raise ValueError("First argument must be either a vocabulary size (int)" + " or the path to a pretrained model config file (str)") + @property + def hidden_size(self): + return self.dim + + @property + def num_attention_heads(self): + return self.n_heads + + @property + def num_hidden_layers(self): + return self.n_layers diff --git a/pytorch_transformers/configuration_gpt2.py b/pytorch_transformers/configuration_gpt2.py new file mode 100644 index 0000000000..c83d9e82ce --- /dev/null +++ b/pytorch_transformers/configuration_gpt2.py @@ -0,0 +1,143 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" OpenAI GPT-2 configuration """ + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import logging +import sys +from io import open + +from .configuration_utils import PretrainedConfig + +logger = logging.getLogger(__name__) + +GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", + "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json", + "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"} + +class GPT2Config(PretrainedConfig): + """Configuration class to store the configuration of a `GPT2Model`. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. + n_positions: Number of positional embeddings. + n_ctx: Size of the causal mask (usually same as n_positions). + n_embd: Dimensionality of the embeddings and hidden states. + n_layer: Number of hidden layers in the Transformer encoder. + n_head: Number of attention heads for each attention layer in + the Transformer encoder. + layer_norm_epsilon: epsilon to use in the layer norm layers + resid_pdrop: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attn_pdrop: The dropout ratio for the attention + probabilities. + embd_pdrop: The dropout ratio for the embeddings. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + """ + pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP + + def __init__( + self, + vocab_size_or_config_json_file=50257, + n_positions=1024, + n_ctx=1024, + n_embd=768, + n_layer=12, + n_head=12, + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + + num_labels=1, + summary_type='cls_index', + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + **kwargs + ): + """Constructs GPT2Config. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. + n_positions: Number of positional embeddings. + n_ctx: Size of the causal mask (usually same as n_positions). + n_embd: Dimensionality of the embeddings and hidden states. + n_layer: Number of hidden layers in the Transformer encoder. + n_head: Number of attention heads for each attention layer in + the Transformer encoder. + layer_norm_epsilon: epsilon to use in the layer norm layers + resid_pdrop: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attn_pdrop: The dropout ratio for the attention + probabilities. + embd_pdrop: The dropout ratio for the embeddings. + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. 
+ """ + super(GPT2Config, self).__init__(**kwargs) + + if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.n_ctx = n_ctx + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + + self.num_labels = num_labels + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + else: + raise ValueError( + "First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)" + ) + + @property + def max_position_embeddings(self): + return self.n_positions + + @property + def hidden_size(self): + return self.n_embd + + @property + def num_attention_heads(self): + return self.n_head + + @property + def num_hidden_layers(self): + return self.n_layer diff --git a/pytorch_transformers/configuration_openai.py b/pytorch_transformers/configuration_openai.py new file mode 100644 index 0000000000..b27df56899 --- /dev/null +++ b/pytorch_transformers/configuration_openai.py @@ -0,0 +1,135 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" OpenAI GPT configuration """ + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import logging +import sys +from io import open + +from .configuration_utils import PretrainedConfig + +logger = logging.getLogger(__name__) + +OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json" +} + +class OpenAIGPTConfig(PretrainedConfig): + """ + Configuration class to store the configuration of a `OpenAIGPTModel`. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file. + n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...) + n_positions: Number of positional embeddings. + n_ctx: Size of the causal mask (usually same as n_positions). + n_embd: Dimensionality of the embeddings and hidden states. + n_layer: Number of hidden layers in the Transformer encoder. + n_head: Number of attention heads for each attention layer in + the Transformer encoder. 
+ afn: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + resid_pdrop: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + attn_pdrop: The dropout ratio for the attention + probabilities. + embd_pdrop: The dropout ratio for the embeddings. + layer_norm_epsilon: epsilon to use in the layer norm layers + initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + predict_special_tokens: should we predict special tokens (when the model has a LM head) + """ + pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP + + def __init__( + self, + vocab_size_or_config_json_file=40478, + n_positions=512, + n_ctx=512, + n_embd=768, + n_layer=12, + n_head=12, + afn="gelu", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + predict_special_tokens=True, + + num_labels=1, + summary_type='cls_index', + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + **kwargs + ): + """Constructs OpenAIGPTConfig. + """ + super(OpenAIGPTConfig, self).__init__(**kwargs) + + if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.vocab_size = vocab_size_or_config_json_file + self.n_ctx = n_ctx + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.afn = afn + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.predict_special_tokens = predict_special_tokens + + self.num_labels = num_labels + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + else: + raise ValueError( + "First argument must be either a vocabulary size (int)" + "or the path to a pretrained model config file (str)" + ) + + @property + def max_position_embeddings(self): + return self.n_positions + + @property + def hidden_size(self): + return self.n_embd + + @property + def num_attention_heads(self): + return self.n_head + + @property + def num_hidden_layers(self): + return self.n_layer diff --git a/pytorch_transformers/configuration_roberta.py b/pytorch_transformers/configuration_roberta.py new file mode 100644 index 0000000000..b92d6a908b --- /dev/null +++ b/pytorch_transformers/configuration_roberta.py @@ -0,0 +1,35 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
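Both GPT configurations follow the same aliasing pattern: the native names (n_embd, n_head, n_layer, n_positions) are stored, and the shared attribute names are exposed as properties. A short sketch using the default values from the two constructors above::

    from pytorch_transformers import GPT2Config, OpenAIGPTConfig

    gpt2 = GPT2Config()            # defaults: n_embd=768, n_head=12, n_positions=1024
    assert gpt2.hidden_size == gpt2.n_embd == 768
    assert gpt2.max_position_embeddings == gpt2.n_positions == 1024

    gpt = OpenAIGPTConfig()        # the original GPT uses n_positions=512
    assert gpt.max_position_embeddings == 512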
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" RoBERTa configuration """ + +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import logging + +from .configuration_bert import BertConfig + +logger = logging.getLogger(__name__) + +ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json", + 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json", + 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json", +} + + +class RobertaConfig(BertConfig): + pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP diff --git a/pytorch_transformers/configuration_transfo_xl.py b/pytorch_transformers/configuration_transfo_xl.py new file mode 100644 index 0000000000..2e966ee55c --- /dev/null +++ b/pytorch_transformers/configuration_transfo_xl.py @@ -0,0 +1,167 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Transformer XL configuration """ + +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import logging +import sys +from io import open + +from .configuration_utils import PretrainedConfig + +logger = logging.getLogger(__name__) + +TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json", +} + +class TransfoXLConfig(PretrainedConfig): + """Configuration class to store the configuration of a `TransfoXLModel`. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file. + cutoffs: cutoffs for the adaptive softmax + d_model: Dimensionality of the model's hidden states. + d_embed: Dimensionality of the embeddings + d_head: Dimensionality of the model's heads. + div_val: divident value for adapative input and softmax + pre_lnorm: apply LayerNorm to the input instead of the output + d_inner: Inner dimension in FF + n_layer: Number of hidden layers in the Transformer encoder. + n_head: Number of attention heads for each attention layer in + the Transformer encoder. + tgt_len: number of tokens to predict + ext_len: length of the extended context + mem_len: length of the retained previous heads + same_length: use the same attn length for all tokens + proj_share_all_but_first: True to share all but first projs, False not to share. 
+ attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al. + clamp_len: use the same pos embeddings after clamp_len + sample_softmax: number of samples in sampled softmax + adaptive: use adaptive softmax + tie_weight: tie the word embedding and softmax weights + dropout: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + dropatt: The dropout ratio for the attention probabilities. + untie_r: untie relative position biases + embd_pdrop: The dropout ratio for the embeddings. + init: parameter initializer to use + init_range: parameters initialized by U(-init_range, init_range). + proj_init_std: parameters initialized by N(0, init_std) + init_std: parameters initialized by N(0, init_std) + """ + pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP + + def __init__(self, + vocab_size_or_config_json_file=267735, + cutoffs=[20000, 40000, 200000], + d_model=1024, + d_embed=1024, + n_head=16, + d_head=64, + d_inner=4096, + div_val=4, + pre_lnorm=False, + n_layer=18, + tgt_len=128, + ext_len=0, + mem_len=1600, + clamp_len=1000, + same_length=True, + proj_share_all_but_first=True, + attn_type=0, + sample_softmax=-1, + adaptive=True, + tie_weight=True, + dropout=0.1, + dropatt=0.0, + untie_r=True, + init="normal", + init_range=0.01, + proj_init_std=0.01, + init_std=0.02, + **kwargs): + """Constructs TransfoXLConfig. + """ + super(TransfoXLConfig, self).__init__(**kwargs) + + if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.n_token = vocab_size_or_config_json_file + self.cutoffs = [] + self.cutoffs.extend(cutoffs) + self.tie_weight = tie_weight + if proj_share_all_but_first: + self.tie_projs = [False] + [True] * len(self.cutoffs) + else: + self.tie_projs = [False] + [False] * len(self.cutoffs) + self.d_model = d_model + self.d_embed = d_embed + self.d_head = d_head + self.d_inner = d_inner + self.div_val = div_val + self.pre_lnorm = pre_lnorm + self.n_layer = n_layer + self.n_head = n_head + self.tgt_len = tgt_len + self.ext_len = ext_len + self.mem_len = mem_len + self.same_length = same_length + self.attn_type = attn_type + self.clamp_len = clamp_len + self.sample_softmax = sample_softmax + self.adaptive = adaptive + self.dropout = dropout + self.dropatt = dropatt + self.untie_r = untie_r + self.init = init + self.init_range = init_range + self.proj_init_std = proj_init_std + self.init_std = init_std + else: + raise ValueError("First argument must be either a vocabulary size (int)" + " or the path to a pretrained model config file (str)") + + @property + def max_position_embeddings(self): + return self.tgt_len + self.ext_len + self.mem_len + + @property + def vocab_size(self): + return self.n_token + + @vocab_size.setter + def vocab_size(self, value): + self.n_token = value + + @property + def hidden_size(self): + return self.d_model + + @property + def num_attention_heads(self): + return self.n_head + + @property + def num_hidden_layers(self): + return self.n_layer diff --git a/pytorch_transformers/configuration_utils.py b/pytorch_transformers/configuration_utils.py new file mode 100644 index 0000000000..550b47fab8 --- /dev/null +++ 
b/pytorch_transformers/configuration_utils.py @@ -0,0 +1,207 @@ +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Configuration base class and utilities.""" + +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import copy +import json +import logging +import os +from io import open + +from .file_utils import cached_path + +logger = logging.getLogger(__name__) + +CONFIG_NAME = "config.json" + +class PretrainedConfig(object): + r""" Base class for all configuration classes. + Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. + + Note: + A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights. + It only affects the model's configuration. + + Class attributes (overridden by derived classes): + - ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values. + + Parameters: + ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. + ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens) + ``output_attentions``: boolean, default `False`. Should the model returns attentions weights. + ``output_hidden_states``: string, default `False`. Should the model returns all hidden-states. + ``torchscript``: string, default `False`. Is the model used with Torchscript. + """ + pretrained_config_archive_map = {} + + def __init__(self, **kwargs): + self.finetuning_task = kwargs.pop('finetuning_task', None) + self.num_labels = kwargs.pop('num_labels', 2) + self.output_attentions = kwargs.pop('output_attentions', False) + self.output_hidden_states = kwargs.pop('output_hidden_states', False) + self.torchscript = kwargs.pop('torchscript', False) + self.pruned_heads = kwargs.pop('pruned_heads', {}) + + def save_pretrained(self, save_directory): + """ Save a configuration object to the directory `save_directory`, so that it + can be re-loaded using the :func:`~pytorch_transformers.PretrainedConfig.from_pretrained` class method. 
+ """ + assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" + + # If we save using the predefined names, we can load using `from_pretrained` + output_config_file = os.path.join(save_directory, CONFIG_NAME) + + self.to_json_file(output_config_file) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): + r""" Instantiate a :class:`~pytorch_transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. + + Parameters: + pretrained_model_name_or_path: either: + + - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. + - a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. + - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. + + cache_dir: (`optional`) string: + Path to a directory in which a downloaded pre-trained model + configuration should be cached if the standard cache should not be used. + + kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading. + + - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. + - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. + + force_download: (`optional`) boolean, default False: + Force to (re-)download the model weights and configuration files and override the cached versions if they exists. + + proxies: (`optional`) dict, default None: + A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. + The proxies are used on each request. + + return_unused_kwargs: (`optional`) bool: + + - If False, then this function returns just the final configuration object. + - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored. + + Examples:: + + # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a + # derived class: BertConfig + config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. + config = BertConfig.from_pretrained('./test/saved_model/') # E.g. 
config (or model) was saved using `save_pretrained('./test/saved_model/')` + config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') + config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) + assert config.output_attention == True + config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, + foo=False, return_unused_kwargs=True) + assert config.output_attention == True + assert unused_kwargs == {'foo': False} + + """ + cache_dir = kwargs.pop('cache_dir', None) + force_download = kwargs.pop('force_download', False) + proxies = kwargs.pop('proxies', None) + return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) + + if pretrained_model_name_or_path in cls.pretrained_config_archive_map: + config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] + elif os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) + else: + config_file = pretrained_model_name_or_path + # redirect to the cache, if necessary + try: + resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) + except EnvironmentError as e: + if pretrained_model_name_or_path in cls.pretrained_config_archive_map: + logger.error( + "Couldn't reach server at '{}' to download pretrained model configuration file.".format( + config_file)) + else: + logger.error( + "Model name '{}' was not found in model name list ({}). " + "We assumed '{}' was a path or url but couldn't find any file " + "associated to this path or url.".format( + pretrained_model_name_or_path, + ', '.join(cls.pretrained_config_archive_map.keys()), + config_file)) + raise e + if resolved_config_file == config_file: + logger.info("loading configuration file {}".format(config_file)) + else: + logger.info("loading configuration file {} from cache at {}".format( + config_file, resolved_config_file)) + + # Load config + config = cls.from_json_file(resolved_config_file) + + if hasattr(config, 'pruned_heads'): + config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items()) + + # Update config with kwargs if needed + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info("Model config %s", config) + if return_unused_kwargs: + return config, kwargs + else: + return config + + @classmethod + def from_dict(cls, json_object): + """Constructs a `Config` from a Python dictionary of parameters.""" + config = cls(vocab_size_or_config_json_file=-1) + for key, value in json_object.items(): + config.__dict__[key] = value + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `BertConfig` from a json file of parameters.""" + with open(json_file, "r", encoding='utf-8') as reader: + text = reader.read() + return cls.from_dict(json.loads(text)) + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return str(self.to_json_string()) + + def to_dict(self): + """Serializes this instance to a Python dictionary.""" + output = copy.deepcopy(self.__dict__) + return output + + def to_json_string(self): + """Serializes this instance to a JSON string.""" + return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + + def to_json_file(self, json_file_path): + """ Save this instance to a json file.""" + with 
open(json_file_path, "w", encoding='utf-8') as writer: + writer.write(self.to_json_string()) diff --git a/pytorch_transformers/configuration_xlm.py b/pytorch_transformers/configuration_xlm.py new file mode 100644 index 0000000000..ab251c8939 --- /dev/null +++ b/pytorch_transformers/configuration_xlm.py @@ -0,0 +1,184 @@ +# coding=utf-8 +# Copyright 2019-present, Facebook, Inc and the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" XLM configuration """ +from __future__ import absolute_import, division, print_function, unicode_literals + +import json +import logging +import sys +from io import open + +from .configuration_utils import PretrainedConfig + +logger = logging.getLogger(__name__) + +XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { + 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json", + 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json", + 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json", + 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json", + 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json", + 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json", + 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json", + 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json", + 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json", + 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json", +} + + +class XLMConfig(PretrainedConfig): + """Configuration class to store the configuration of a `XLMModel`. + + Args: + vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XLMModel`. + d_model: Size of the encoder layers and the pooler layer. + n_layer: Number of hidden layers in the Transformer encoder. + n_head: Number of attention heads for each attention layer in + the Transformer encoder. + d_inner: The size of the "intermediate" (i.e., feed-forward) + layer in the Transformer encoder. + ff_activation: The non-linear activation function (function or string) in the + encoder and pooler. If string, "gelu", "relu" and "swish" are supported. + untie_r: untie relative position biases + attn_type: 'bi' for XLM, 'uni' for Transformer-XL + + dropout: The dropout probabilitiy for all fully connected + layers in the embeddings, encoder, and pooler. + dropatt: The dropout ratio for the attention + probabilities. + max_position_embeddings: The maximum sequence length that this model might + ever be used with. Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). 
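As a side note on the shared PretrainedConfig base class introduced in configuration_utils.py above: save_pretrained() writes CONFIG_NAME ('config.json') into an existing directory, and from_pretrained() restores it, with keyword arguments that match configuration attributes overriding the loaded values. A sketch of that round trip; the temporary directory and the extra 'foo' key are only for illustration::

    import tempfile

    from pytorch_transformers import BertConfig

    config = BertConfig(num_labels=3, output_attentions=True)

    save_dir = tempfile.mkdtemp()          # save_pretrained expects an existing directory
    config.save_pretrained(save_dir)       # writes <save_dir>/config.json

    # Attribute kwargs override loaded values; unknown keys are returned separately
    reloaded, unused = BertConfig.from_pretrained(save_dir, num_labels=5, foo=1,
                                                  return_unused_kwargs=True)
    assert reloaded.num_labels == 5
    assert unused == {'foo': 1}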
+ initializer_range: The sttdev of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_eps: The epsilon used by LayerNorm. + + dropout: float, dropout rate. + dropatt: float, dropout rate on attention probabilities. + init: str, the initialization scheme, either "normal" or "uniform". + init_range: float, initialize the parameters with a uniform distribution + in [-init_range, init_range]. Only effective when init="uniform". + init_std: float, initialize the parameters with a normal distribution + with mean 0 and stddev init_std. Only effective when init="normal". + mem_len: int, the number of tokens to cache. + reuse_len: int, the number of tokens in the currect batch to be cached + and reused in the future. + bi_data: bool, whether to use bidirectional input pipeline. + Usually set to True during pretraining and False during finetuning. + clamp_len: int, clamp all relative distances larger than clamp_len. + -1 means no clamping. + same_length: bool, whether to use the same attention length for each token. + """ + pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP + + def __init__(self, + vocab_size_or_config_json_file=30145, + emb_dim=2048, + n_layers=12, + n_heads=16, + dropout=0.1, + attention_dropout=0.1, + gelu_activation=True, + sinusoidal_embeddings=False, + causal=False, + asm=False, + n_langs=1, + use_lang_emb=True, + max_position_embeddings=512, + embed_init_std=2048 ** -0.5, + layer_norm_eps=1e-12, + init_std=0.02, + bos_index=0, + eos_index=1, + pad_index=2, + unk_index=3, + mask_index=5, + is_encoder=True, + + finetuning_task=None, + num_labels=2, + summary_type='first', + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + start_n_top=5, + end_n_top=5, + **kwargs): + """Constructs XLMConfig. 
+ """ + super(XLMConfig, self).__init__(**kwargs) + + if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.n_words = vocab_size_or_config_json_file + self.emb_dim = emb_dim + self.n_layers = n_layers + self.n_heads = n_heads + self.dropout = dropout + self.attention_dropout = attention_dropout + self.gelu_activation = gelu_activation + self.sinusoidal_embeddings = sinusoidal_embeddings + self.causal = causal + self.asm = asm + self.n_langs = n_langs + self.use_lang_emb = use_lang_emb + self.layer_norm_eps = layer_norm_eps + self.bos_index = bos_index + self.eos_index = eos_index + self.pad_index = pad_index + self.unk_index = unk_index + self.mask_index = mask_index + self.is_encoder = is_encoder + self.max_position_embeddings = max_position_embeddings + self.embed_init_std = embed_init_std + self.init_std = init_std + self.finetuning_task = finetuning_task + self.num_labels = num_labels + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_proj_to_labels = summary_proj_to_labels + self.summary_first_dropout = summary_first_dropout + self.start_n_top = start_n_top + self.end_n_top = end_n_top + else: + raise ValueError("First argument must be either a vocabulary size (int)" + " or the path to a pretrained model config file (str)") + + @property + def vocab_size(self): + return self.n_words + + @vocab_size.setter + def vocab_size(self, value): + self.n_words = value + + @property + def hidden_size(self): + return self.emb_dim + + @property + def num_attention_heads(self): + return self.n_heads + + @property + def num_hidden_layers(self): + return self.n_layers diff --git a/pytorch_transformers/configuration_xlnet.py b/pytorch_transformers/configuration_xlnet.py new file mode 100644 index 0000000000..204d44aa72 --- /dev/null +++ b/pytorch_transformers/configuration_xlnet.py @@ -0,0 +1,172 @@ +# coding=utf-8 +# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
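XLMConfig stores XLM's native names (n_words, emb_dim, n_layers, n_heads) and exposes vocab_size as a read/write property that forwards to n_words. A short sketch with the defaults above; the new vocabulary size is an arbitrary illustrative number::

    from pytorch_transformers import XLMConfig

    config = XLMConfig()           # defaults: emb_dim=2048, n_layers=12, n_heads=16
    assert config.hidden_size == config.emb_dim == 2048
    assert config.num_attention_heads == config.n_heads == 16

    config.vocab_size = 95000      # writes through to the native attribute
    assert config.n_words == 95000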
+""" XLNet configuration """
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import json
+import logging
+import sys
+from io import open
+
+from .configuration_utils import PretrainedConfig
+
+logger = logging.getLogger(__name__)
+
+XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = {
+    'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json",
+    'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json",
+}
+
+
+class XLNetConfig(PretrainedConfig):
+    """Configuration class to store the configuration of a ``XLNetModel``.
+
+    Args:
+        vocab_size_or_config_json_file: Vocabulary size of ``inputs_ids`` in ``XLNetModel``.
+        d_model: Size of the encoder layers and the pooler layer.
+        n_layer: Number of hidden layers in the Transformer encoder.
+        n_head: Number of attention heads for each attention layer in
+            the Transformer encoder.
+        d_inner: The size of the "intermediate" (i.e., feed-forward)
+            layer in the Transformer encoder.
+        ff_activation: The non-linear activation function (function or string) in the
+            encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
+        untie_r: untie relative position biases
+        attn_type: 'bi' for XLNet, 'uni' for Transformer-XL
+
+        dropout: The dropout probability for all fully connected
+            layers in the embeddings, encoder, and pooler.
+        dropatt: The dropout ratio for the attention
+            probabilities.
+        initializer_range: The stddev of the truncated_normal_initializer for
+            initializing all weight matrices.
+        layer_norm_eps: The epsilon used by LayerNorm.
+
+        dropout: float, dropout rate.
+        dropatt: float, dropout rate on attention probabilities.
+        init: str, the initialization scheme, either "normal" or "uniform".
+        init_range: float, initialize the parameters with a uniform distribution
+            in [-init_range, init_range]. Only effective when init="uniform".
+        init_std: float, initialize the parameters with a normal distribution
+            with mean 0 and stddev init_std. Only effective when init="normal".
+        mem_len: int, the number of tokens to cache.
+        reuse_len: int, the number of tokens in the current batch to be cached
+            and reused in the future.
+        bi_data: bool, whether to use bidirectional input pipeline.
+            Usually set to True during pretraining and False during finetuning.
+        clamp_len: int, clamp all relative distances larger than clamp_len.
+            -1 means no clamping.
+        same_length: bool, whether to use the same attention length for each token.
+        finetuning_task: name of the GLUE task on which the model was fine-tuned if any
+    """
+    pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
+
+    def __init__(self,
+                 vocab_size_or_config_json_file=32000,
+                 d_model=1024,
+                 n_layer=24,
+                 n_head=16,
+                 d_inner=4096,
+                 ff_activation="gelu",
+                 untie_r=True,
+                 attn_type="bi",
+
+                 initializer_range=0.02,
+                 layer_norm_eps=1e-12,
+
+                 dropout=0.1,
+                 mem_len=None,
+                 reuse_len=None,
+                 bi_data=False,
+                 clamp_len=-1,
+                 same_length=False,
+
+                 finetuning_task=None,
+                 num_labels=2,
+                 summary_type='last',
+                 summary_use_proj=True,
+                 summary_activation='tanh',
+                 summary_last_dropout=0.1,
+                 start_n_top=5,
+                 end_n_top=5,
+                 **kwargs):
+        """Constructs XLNetConfig.
+ """ + super(XLNetConfig, self).__init__(**kwargs) + + if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 + and isinstance(vocab_size_or_config_json_file, unicode)): + with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: + json_config = json.loads(reader.read()) + for key, value in json_config.items(): + self.__dict__[key] = value + elif isinstance(vocab_size_or_config_json_file, int): + self.n_token = vocab_size_or_config_json_file + self.d_model = d_model + self.n_layer = n_layer + self.n_head = n_head + assert d_model % n_head == 0 + self.d_head = d_model // n_head + self.ff_activation = ff_activation + self.d_inner = d_inner + self.untie_r = untie_r + self.attn_type = attn_type + + self.initializer_range = initializer_range + self.layer_norm_eps = layer_norm_eps + + self.dropout = dropout + self.mem_len = mem_len + self.reuse_len = reuse_len + self.bi_data = bi_data + self.clamp_len = clamp_len + self.same_length = same_length + + self.finetuning_task = finetuning_task + self.num_labels = num_labels + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_last_dropout = summary_last_dropout + self.start_n_top = start_n_top + self.end_n_top = end_n_top + else: + raise ValueError("First argument must be either a vocabulary size (int)" + " or the path to a pretrained model config file (str)") + + @property + def max_position_embeddings(self): + return -1 + + @property + def vocab_size(self): + return self.n_token + + @vocab_size.setter + def vocab_size(self, value): + self.n_token = value + + @property + def hidden_size(self): + return self.d_model + + @property + def num_attention_heads(self): + return self.n_head + + @property + def num_hidden_layers(self): + return self.n_layer diff --git a/pytorch_transformers/file_utils.py b/pytorch_transformers/file_utils.py index f6f2151b12..37ebc57fb3 100644 --- a/pytorch_transformers/file_utils.py +++ b/pytorch_transformers/file_utils.py @@ -9,6 +9,7 @@ import sys import json import logging import os +import six import shutil import tempfile import fnmatch @@ -49,6 +50,29 @@ PYTORCH_TRANSFORMERS_CACHE = PYTORCH_PRETRAINED_BERT_CACHE # Kept for backward logger = logging.getLogger(__name__) # pylint: disable=invalid-name +if not six.PY2: + def add_start_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = ''.join(docstr) + fn.__doc__ + return fn + return docstring_decorator + + def add_end_docstrings(*docstr): + def docstring_decorator(fn): + fn.__doc__ = fn.__doc__ + ''.join(docstr) + return fn + return docstring_decorator +else: + # Not possible to update class docstrings on python2 + def add_start_docstrings(*docstr): + def docstring_decorator(fn): + return fn + return docstring_decorator + + def add_end_docstrings(*docstr): + def docstring_decorator(fn): + return fn + return docstring_decorator def url_to_filename(url, etag=None): """ diff --git a/pytorch_transformers/modeling_auto.py b/pytorch_transformers/modeling_auto.py index 05ff5e5b33..31c8fafaa9 100644 --- a/pytorch_transformers/modeling_auto.py +++ b/pytorch_transformers/modeling_auto.py @@ -18,125 +18,22 @@ from __future__ import absolute_import, division, print_function, unicode_litera import logging -from .modeling_bert import BertConfig, BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering -from .modeling_openai import OpenAIGPTConfig, OpenAIGPTModel, OpenAIGPTLMHeadModel -from .modeling_gpt2 import 
GPT2Config, GPT2Model, GPT2LMHeadModel -from .modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel -from .modeling_xlnet import XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering -from .modeling_xlm import XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering -from .modeling_roberta import RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification -from .modeling_distilbert import DistilBertConfig, DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification +from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering +from .modeling_openai import OpenAIGPTModel, OpenAIGPTLMHeadModel +from .modeling_gpt2 import GPT2Model, GPT2LMHeadModel +from .modeling_transfo_xl import TransfoXLModel, TransfoXLLMHeadModel +from .modeling_xlnet import XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering +from .modeling_xlm import XLMModel, XLMWithLMHeadModel, XLMForSequenceClassification, XLMForQuestionAnswering +from .modeling_roberta import RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification +from .modeling_distilbert import DistilBertModel, DistilBertForQuestionAnswering, DistilBertForMaskedLM, DistilBertForSequenceClassification -from .modeling_utils import PreTrainedModel, SequenceSummary, add_start_docstrings +from .modeling_utils import PreTrainedModel, SequenceSummary + +from .file_utils import add_start_docstrings logger = logging.getLogger(__name__) -class AutoConfig(object): - r""":class:`~pytorch_transformers.AutoConfig` is a generic configuration class - that will be instantiated as one of the configuration classes of the library - when created with the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` - class method. - - The `from_pretrained()` method take care of returning the correct model class instance - using pattern matching on the `pretrained_model_name_or_path` string. - - The base model class to instantiate is selected as the first pattern matching - in the `pretrained_model_name_or_path` string (in the following order): - - contains `distilbert`: DistilBertConfig (DistilBERT model) - - contains `bert`: BertConfig (Bert model) - - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) - - contains `gpt2`: GPT2Config (OpenAI GPT-2 model) - - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model) - - contains `xlnet`: XLNetConfig (XLNet model) - - contains `xlm`: XLMConfig (XLM model) - - contains `roberta`: RobertaConfig (RoBERTa model) - - This class cannot be instantiated using `__init__()` (throw an error). - """ - def __init__(self): - raise EnvironmentError("AutoConfig is designed to be instantiated " - "using the `AutoConfig.from_pretrained(pretrained_model_name_or_path)` method.") - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - r""" Instantiate a one of the configuration classes of the library - from a pre-trained model configuration. 
- - The configuration class to instantiate is selected as the first pattern matching - in the `pretrained_model_name_or_path` string (in the following order): - - contains `distilbert`: DistilBertConfig (DistilBERT model) - - contains `bert`: BertConfig (Bert model) - - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model) - - contains `gpt2`: GPT2Config (OpenAI GPT-2 model) - - contains `transfo-xl`: TransfoXLConfig (Transformer-XL model) - - contains `xlnet`: XLNetConfig (XLNet model) - - contains `xlm`: XLMConfig (XLM model) - - contains `roberta`: RobertaConfig (RoBERTa model) - - Params: - pretrained_model_name_or_path: either: - - - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. - - a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. - - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. - - cache_dir: (`optional`) string: - Path to a directory in which a downloaded pre-trained model - configuration should be cached if the standard cache should not be used. - - kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading. - - - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. - - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. - - force_download: (`optional`) boolean, default False: - Force to (re-)download the model weights and configuration files and override the cached versions if they exists. - - proxies: (`optional`) dict, default None: - A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. - The proxies are used on each request. - - return_unused_kwargs: (`optional`) bool: - - - If False, then this function returns just the final configuration object. - - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored. - - Examples:: - - config = AutoConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. - config = AutoConfig.from_pretrained('./test/bert_saved_model/') # E.g. 
config (or model) was saved using `save_pretrained('./test/saved_model/')` - config = AutoConfig.from_pretrained('./test/bert_saved_model/my_configuration.json') - config = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) - assert config.output_attention == True - config, unused_kwargs = AutoConfig.from_pretrained('bert-base-uncased', output_attention=True, - foo=False, return_unused_kwargs=True) - assert config.output_attention == True - assert unused_kwargs == {'foo': False} - - """ - if 'distilbert' in pretrained_model_name_or_path: - return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif 'roberta' in pretrained_model_name_or_path: - return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif 'bert' in pretrained_model_name_or_path: - return BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif 'openai-gpt' in pretrained_model_name_or_path: - return OpenAIGPTConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif 'gpt2' in pretrained_model_name_or_path: - return GPT2Config.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif 'transfo-xl' in pretrained_model_name_or_path: - return TransfoXLConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif 'xlnet' in pretrained_model_name_or_path: - return XLNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - elif 'xlm' in pretrained_model_name_or_path: - return XLMConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) - - raise ValueError("Unrecognized model identifier in {}. Should contains one of " - "'bert', 'openai-gpt', 'gpt2', 'transfo-xl', 'xlnet', " - "'xlm', 'roberta'".format(pretrained_model_name_or_path)) - - class AutoModel(object): r""" :class:`~pytorch_transformers.AutoModel` is a generic model class diff --git a/pytorch_transformers/modeling_bert.py b/pytorch_transformers/modeling_bert.py index 5c71bedba9..c541d18da5 100644 --- a/pytorch_transformers/modeling_bert.py +++ b/pytorch_transformers/modeling_bert.py @@ -28,8 +28,9 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss, MSELoss -from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, PretrainedConfig, PreTrainedModel, - prune_linear_layer, add_start_docstrings) +from .modeling_utils import PreTrainedModel, prune_linear_layer +from .configuration_bert import BertConfig +from .file_utils import add_start_docstrings logger = logging.getLogger(__name__) @@ -49,23 +50,6 @@ BERT_PRETRAINED_MODEL_ARCHIVE_MAP = { 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin", } -BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { - 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json", - 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json", - 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json", - 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json", - 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json", - 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json", - 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json", - 
'bert-base-german-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json", - 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json", - 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json", - 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json", - 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json", - 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json", -} - - def load_tf_weights_in_bert(model, config, tf_checkpoint_path): """ Load tf checkpoints in a pytorch model. """ @@ -149,77 +133,6 @@ def swish(x): ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} -class BertConfig(PretrainedConfig): - r""" - :class:`~pytorch_transformers.BertConfig` is the configuration class to store the configuration of a - `BertModel`. - - - Arguments: - vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`. - hidden_size: Size of the encoder layers and the pooler layer. - num_hidden_layers: Number of hidden layers in the Transformer encoder. - num_attention_heads: Number of attention heads for each attention layer in - the Transformer encoder. - intermediate_size: The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder. - hidden_act: The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported. - hidden_dropout_prob: The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler. - attention_probs_dropout_prob: The dropout ratio for the attention - probabilities. - max_position_embeddings: The maximum sequence length that this model might - ever be used with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048). - type_vocab_size: The vocabulary size of the `token_type_ids` passed into - `BertModel`. - initializer_range: The sttdev of the truncated_normal_initializer for - initializing all weight matrices. - layer_norm_eps: The epsilon used by LayerNorm. 
- """ - pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP - - def __init__(self, - vocab_size_or_config_json_file=30522, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=2, - initializer_range=0.02, - layer_norm_eps=1e-12, - **kwargs): - super(BertConfig, self).__init__(**kwargs) - if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 - and isinstance(vocab_size_or_config_json_file, unicode)): - with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - elif isinstance(vocab_size_or_config_json_file, int): - self.vocab_size = vocab_size_or_config_json_file - self.hidden_size = hidden_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - self.hidden_act = hidden_act - self.intermediate_size = intermediate_size - self.hidden_dropout_prob = hidden_dropout_prob - self.attention_probs_dropout_prob = attention_probs_dropout_prob - self.max_position_embeddings = max_position_embeddings - self.type_vocab_size = type_vocab_size - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - else: - raise ValueError("First argument must be either a vocabulary size (int)" - " or the path to a pretrained model config file (str)") - - - try: from apex.normalization.fused_layer_norm import FusedLayerNorm as BertLayerNorm except (ImportError, AttributeError) as e: diff --git a/pytorch_transformers/modeling_distilbert.py b/pytorch_transformers/modeling_distilbert.py index 280270381c..29ecbfb846 100644 --- a/pytorch_transformers/modeling_distilbert.py +++ b/pytorch_transformers/modeling_distilbert.py @@ -31,7 +31,9 @@ import numpy as np import torch import torch.nn as nn -from pytorch_transformers.modeling_utils import PretrainedConfig, PreTrainedModel, add_start_docstrings, prune_linear_layer +from .modeling_utils import PreTrainedModel, prune_linear_layer +from .configuration_distilbert import DistilBertConfig +from .file_utils import add_start_docstrings import logging logger = logging.getLogger(__name__) @@ -42,69 +44,6 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = { 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-pytorch_model.bin" } -DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = { - 'distilbert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-config.json", - 'distilbert-base-uncased-distilled-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/distilbert-base-uncased-distilled-squad-config.json" -} - - -class DistilBertConfig(PretrainedConfig): - pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP - - def __init__(self, - vocab_size_or_config_json_file=30522, - max_position_embeddings=512, - sinusoidal_pos_embds=True, - n_layers=6, - n_heads=12, - dim=768, - hidden_dim=4*768, - dropout=0.1, - attention_dropout=0.1, - activation='gelu', - initializer_range=0.02, - tie_weights_=True, - qa_dropout=0.1, - seq_classif_dropout=0.2, - **kwargs): - super(DistilBertConfig, self).__init__(**kwargs) - - if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 - and isinstance(vocab_size_or_config_json_file, unicode)): - 
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - elif isinstance(vocab_size_or_config_json_file, int): - self.vocab_size = vocab_size_or_config_json_file - self.max_position_embeddings = max_position_embeddings - self.sinusoidal_pos_embds = sinusoidal_pos_embds - self.n_layers = n_layers - self.n_heads = n_heads - self.dim = dim - self.hidden_dim = hidden_dim - self.dropout = dropout - self.attention_dropout = attention_dropout - self.activation = activation - self.initializer_range = initializer_range - self.tie_weights_ = tie_weights_ - self.qa_dropout = qa_dropout - self.seq_classif_dropout = seq_classif_dropout - else: - raise ValueError("First argument must be either a vocabulary size (int)" - " or the path to a pretrained model config file (str)") - @property - def hidden_size(self): - return self.dim - - @property - def num_attention_heads(self): - return self.n_heads - - @property - def num_hidden_layers(self): - return self.n_layers - ### UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE ### def gelu(x): diff --git a/pytorch_transformers/modeling_gpt2.py b/pytorch_transformers/modeling_gpt2.py index d16448beaa..4268641187 100644 --- a/pytorch_transformers/modeling_gpt2.py +++ b/pytorch_transformers/modeling_gpt2.py @@ -30,19 +30,15 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss from torch.nn.parameter import Parameter -from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, - PreTrainedModel, prune_conv1d_layer, SequenceSummary, - add_start_docstrings) -from .modeling_bert import BertLayerNorm as LayerNorm +from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary +from .configuration_gpt2 import GPT2Config +from .file_utils import add_start_docstrings logger = logging.getLogger(__name__) GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin", "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin", "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"} -GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json", - "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json", - "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"} def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): """ Load tf checkpoints in a pytorch model @@ -102,120 +98,6 @@ def gelu(x): return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) -class GPT2Config(PretrainedConfig): - """Configuration class to store the configuration of a `GPT2Model`. - - Args: - vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. - n_positions: Number of positional embeddings. - n_ctx: Size of the causal mask (usually same as n_positions). - n_embd: Dimensionality of the embeddings and hidden states. - n_layer: Number of hidden layers in the Transformer encoder. - n_head: Number of attention heads for each attention layer in - the Transformer encoder. - layer_norm_epsilon: epsilon to use in the layer norm layers - resid_pdrop: The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler. 
- attn_pdrop: The dropout ratio for the attention - probabilities. - embd_pdrop: The dropout ratio for the embeddings. - initializer_range: The sttdev of the truncated_normal_initializer for - initializing all weight matrices. - """ - pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP - - def __init__( - self, - vocab_size_or_config_json_file=50257, - n_positions=1024, - n_ctx=1024, - n_embd=768, - n_layer=12, - n_head=12, - resid_pdrop=0.1, - embd_pdrop=0.1, - attn_pdrop=0.1, - layer_norm_epsilon=1e-5, - initializer_range=0.02, - - num_labels=1, - summary_type='cls_index', - summary_use_proj=True, - summary_activation=None, - summary_proj_to_labels=True, - summary_first_dropout=0.1, - **kwargs - ): - """Constructs GPT2Config. - - Args: - vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `GPT2Model` or a configuration json file. - n_positions: Number of positional embeddings. - n_ctx: Size of the causal mask (usually same as n_positions). - n_embd: Dimensionality of the embeddings and hidden states. - n_layer: Number of hidden layers in the Transformer encoder. - n_head: Number of attention heads for each attention layer in - the Transformer encoder. - layer_norm_epsilon: epsilon to use in the layer norm layers - resid_pdrop: The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler. - attn_pdrop: The dropout ratio for the attention - probabilities. - embd_pdrop: The dropout ratio for the embeddings. - initializer_range: The sttdev of the truncated_normal_initializer for - initializing all weight matrices. - """ - super(GPT2Config, self).__init__(**kwargs) - - if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 - and isinstance(vocab_size_or_config_json_file, unicode)): - with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - elif isinstance(vocab_size_or_config_json_file, int): - self.vocab_size = vocab_size_or_config_json_file - self.n_ctx = n_ctx - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_head = n_head - self.resid_pdrop = resid_pdrop - self.embd_pdrop = embd_pdrop - self.attn_pdrop = attn_pdrop - self.layer_norm_epsilon = layer_norm_epsilon - self.initializer_range = initializer_range - - self.num_labels = num_labels - self.summary_type = summary_type - self.summary_use_proj = summary_use_proj - self.summary_activation = summary_activation - self.summary_first_dropout = summary_first_dropout - self.summary_proj_to_labels = summary_proj_to_labels - else: - raise ValueError( - "First argument must be either a vocabulary size (int)" - "or the path to a pretrained model config file (str)" - ) - - @property - def max_position_embeddings(self): - return self.n_positions - - @property - def hidden_size(self): - return self.n_embd - - @property - def num_attention_heads(self): - return self.n_head - - @property - def num_hidden_layers(self): - return self.n_layer - - - class Attention(nn.Module): def __init__(self, nx, n_ctx, config, scale=False): super(Attention, self).__init__() @@ -336,9 +218,9 @@ class Block(nn.Module): def __init__(self, n_ctx, config, scale=False): super(Block, self).__init__() nx = config.n_embd - self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) + self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) self.attn = Attention(nx, n_ctx, config, scale) - self.ln_2 = LayerNorm(nx, 
eps=config.layer_norm_epsilon) + self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) self.mlp = MLP(4 * nx, config) def forward(self, x, layer_past=None, attention_mask=None, head_mask=None): @@ -377,7 +259,7 @@ class GPT2PreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None: module.bias.data.zero_() - elif isinstance(module, LayerNorm): + elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) @@ -469,7 +351,7 @@ class GPT2Model(GPT2PreTrainedModel): self.wpe = nn.Embedding(config.n_positions, config.n_embd) self.drop = nn.Dropout(config.embd_pdrop) self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)]) - self.ln_f = LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) self.init_weights() diff --git a/pytorch_transformers/modeling_openai.py b/pytorch_transformers/modeling_openai.py index 4fbec7a768..9936e72030 100644 --- a/pytorch_transformers/modeling_openai.py +++ b/pytorch_transformers/modeling_openai.py @@ -30,15 +30,13 @@ import torch.nn as nn from torch.nn import CrossEntropyLoss from torch.nn.parameter import Parameter -from .modeling_utils import (Conv1D, CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, - PreTrainedModel, prune_conv1d_layer, SequenceSummary, - add_start_docstrings) -from .modeling_bert import BertLayerNorm as LayerNorm +from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary +from .configuration_openai import OpenAIGPTConfig +from .file_utils import add_start_docstrings logger = logging.getLogger(__name__) OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-pytorch_model.bin"} -OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {"openai-gpt": "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-config.json"} def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path): @@ -127,111 +125,6 @@ def swish(x): ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu} -class OpenAIGPTConfig(PretrainedConfig): - """ - Configuration class to store the configuration of a `OpenAIGPTModel`. - - Args: - vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `OpenAIGPTModel` or a configuration json file. - n_special: The number of special tokens to learn during fine-tuning ('[SEP]', '[CLF]', ...) - n_positions: Number of positional embeddings. - n_ctx: Size of the causal mask (usually same as n_positions). - n_embd: Dimensionality of the embeddings and hidden states. - n_layer: Number of hidden layers in the Transformer encoder. - n_head: Number of attention heads for each attention layer in - the Transformer encoder. - afn: The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported. - resid_pdrop: The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler. - attn_pdrop: The dropout ratio for the attention - probabilities. - embd_pdrop: The dropout ratio for the embeddings. - layer_norm_epsilon: epsilon to use in the layer norm layers - initializer_range: The sttdev of the truncated_normal_initializer for - initializing all weight matrices. 
- predict_special_tokens: should we predict special tokens (when the model has a LM head) - """ - pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP - - def __init__( - self, - vocab_size_or_config_json_file=40478, - n_positions=512, - n_ctx=512, - n_embd=768, - n_layer=12, - n_head=12, - afn="gelu", - resid_pdrop=0.1, - embd_pdrop=0.1, - attn_pdrop=0.1, - layer_norm_epsilon=1e-5, - initializer_range=0.02, - predict_special_tokens=True, - - num_labels=1, - summary_type='cls_index', - summary_use_proj=True, - summary_activation=None, - summary_proj_to_labels=True, - summary_first_dropout=0.1, - **kwargs - ): - """Constructs OpenAIGPTConfig. - """ - super(OpenAIGPTConfig, self).__init__(**kwargs) - - if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 - and isinstance(vocab_size_or_config_json_file, unicode)): - with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - elif isinstance(vocab_size_or_config_json_file, int): - self.vocab_size = vocab_size_or_config_json_file - self.n_ctx = n_ctx - self.n_positions = n_positions - self.n_embd = n_embd - self.n_layer = n_layer - self.n_head = n_head - self.afn = afn - self.resid_pdrop = resid_pdrop - self.embd_pdrop = embd_pdrop - self.attn_pdrop = attn_pdrop - self.layer_norm_epsilon = layer_norm_epsilon - self.initializer_range = initializer_range - self.predict_special_tokens = predict_special_tokens - - self.num_labels = num_labels - self.summary_type = summary_type - self.summary_use_proj = summary_use_proj - self.summary_activation = summary_activation - self.summary_first_dropout = summary_first_dropout - self.summary_proj_to_labels = summary_proj_to_labels - else: - raise ValueError( - "First argument must be either a vocabulary size (int)" - "or the path to a pretrained model config file (str)" - ) - - @property - def max_position_embeddings(self): - return self.n_positions - - @property - def hidden_size(self): - return self.n_embd - - @property - def num_attention_heads(self): - return self.n_head - - @property - def num_hidden_layers(self): - return self.n_layer - - class Attention(nn.Module): def __init__(self, nx, n_ctx, config, scale=False): super(Attention, self).__init__() @@ -346,9 +239,9 @@ class Block(nn.Module): super(Block, self).__init__() nx = config.n_embd self.attn = Attention(nx, n_ctx, config, scale) - self.ln_1 = LayerNorm(nx, eps=config.layer_norm_epsilon) + self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) self.mlp = MLP(4 * nx, config) - self.ln_2 = LayerNorm(nx, eps=config.layer_norm_epsilon) + self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon) def forward(self, x, attention_mask=None, head_mask=None): attn_outputs = self.attn(x, attention_mask=attention_mask, head_mask=head_mask) @@ -380,7 +273,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel): module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None: module.bias.data.zero_() - elif isinstance(module, LayerNorm): + elif isinstance(module, nn.LayerNorm): module.bias.data.zero_() module.weight.data.fill_(1.0) diff --git a/pytorch_transformers/modeling_roberta.py b/pytorch_transformers/modeling_roberta.py index 0694e76415..41027330a8 100644 --- a/pytorch_transformers/modeling_roberta.py +++ b/pytorch_transformers/modeling_roberta.py @@ -22,14 +22,11 @@ import logging import torch 
import torch.nn as nn -import torch.nn.functional as F from torch.nn import CrossEntropyLoss, MSELoss -from pytorch_transformers.modeling_bert import (BertConfig, BertEmbeddings, - BertLayerNorm, BertModel, - BertPreTrainedModel, gelu) - -from pytorch_transformers.modeling_utils import add_start_docstrings +from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu +from .configuration_roberta import RobertaConfig +from .file_utils import add_start_docstrings logger = logging.getLogger(__name__) @@ -39,13 +36,6 @@ ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = { 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-pytorch_model.bin", } -ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = { - 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json", - 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json", - 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json", -} - - class RobertaEmbeddings(BertEmbeddings): """ Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. @@ -66,10 +56,6 @@ class RobertaEmbeddings(BertEmbeddings): position_ids=position_ids) -class RobertaConfig(BertConfig): - pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP - - ROBERTA_START_DOCSTRING = r""" The RoBERTa model was proposed in `RoBERTa: A Robustly Optimized BERT Pretraining Approach`_ by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, diff --git a/pytorch_transformers/modeling_transfo_xl.py b/pytorch_transformers/modeling_transfo_xl.py index 3fff0c552e..7ad9d10891 100644 --- a/pytorch_transformers/modeling_transfo_xl.py +++ b/pytorch_transformers/modeling_transfo_xl.py @@ -34,18 +34,16 @@ import torch.nn.functional as F from torch.nn import CrossEntropyLoss from torch.nn.parameter import Parameter -from .modeling_bert import BertLayerNorm as LayerNorm +from .modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary +from .configuration_transfo_xl import TransfoXLConfig from .modeling_transfo_xl_utilities import ProjectedAdaptiveLogSoftmax, sample_logits -from .modeling_utils import (PretrainedConfig, PreTrainedModel, add_start_docstrings) +from .file_utils import add_start_docstrings logger = logging.getLogger(__name__) TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = { 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-pytorch_model.bin", } -TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = { - 'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-config.json", -} def build_tf_to_pytorch_map(model, config): """ A map of modules from TF to PyTorch. @@ -175,143 +173,6 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path): return model -class TransfoXLConfig(PretrainedConfig): - """Configuration class to store the configuration of a `TransfoXLModel`. - - Args: - vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `TransfoXLModel` or a configuration json file. - cutoffs: cutoffs for the adaptive softmax - d_model: Dimensionality of the model's hidden states. - d_embed: Dimensionality of the embeddings - d_head: Dimensionality of the model's heads. 
- div_val: divident value for adapative input and softmax - pre_lnorm: apply LayerNorm to the input instead of the output - d_inner: Inner dimension in FF - n_layer: Number of hidden layers in the Transformer encoder. - n_head: Number of attention heads for each attention layer in - the Transformer encoder. - tgt_len: number of tokens to predict - ext_len: length of the extended context - mem_len: length of the retained previous heads - same_length: use the same attn length for all tokens - proj_share_all_but_first: True to share all but first projs, False not to share. - attn_type: attention type. 0 for Transformer-XL, 1 for Shaw et al, 2 for Vaswani et al, 3 for Al Rfou et al. - clamp_len: use the same pos embeddings after clamp_len - sample_softmax: number of samples in sampled softmax - adaptive: use adaptive softmax - tie_weight: tie the word embedding and softmax weights - dropout: The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler. - dropatt: The dropout ratio for the attention probabilities. - untie_r: untie relative position biases - embd_pdrop: The dropout ratio for the embeddings. - init: parameter initializer to use - init_range: parameters initialized by U(-init_range, init_range). - proj_init_std: parameters initialized by N(0, init_std) - init_std: parameters initialized by N(0, init_std) - """ - pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP - - def __init__(self, - vocab_size_or_config_json_file=267735, - cutoffs=[20000, 40000, 200000], - d_model=1024, - d_embed=1024, - n_head=16, - d_head=64, - d_inner=4096, - div_val=4, - pre_lnorm=False, - n_layer=18, - tgt_len=128, - ext_len=0, - mem_len=1600, - clamp_len=1000, - same_length=True, - proj_share_all_but_first=True, - attn_type=0, - sample_softmax=-1, - adaptive=True, - tie_weight=True, - dropout=0.1, - dropatt=0.0, - untie_r=True, - init="normal", - init_range=0.01, - proj_init_std=0.01, - init_std=0.02, - **kwargs): - """Constructs TransfoXLConfig. 
- """ - super(TransfoXLConfig, self).__init__(**kwargs) - - if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 - and isinstance(vocab_size_or_config_json_file, unicode)): - with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - elif isinstance(vocab_size_or_config_json_file, int): - self.n_token = vocab_size_or_config_json_file - self.cutoffs = [] - self.cutoffs.extend(cutoffs) - self.tie_weight = tie_weight - if proj_share_all_but_first: - self.tie_projs = [False] + [True] * len(self.cutoffs) - else: - self.tie_projs = [False] + [False] * len(self.cutoffs) - self.d_model = d_model - self.d_embed = d_embed - self.d_head = d_head - self.d_inner = d_inner - self.div_val = div_val - self.pre_lnorm = pre_lnorm - self.n_layer = n_layer - self.n_head = n_head - self.tgt_len = tgt_len - self.ext_len = ext_len - self.mem_len = mem_len - self.same_length = same_length - self.attn_type = attn_type - self.clamp_len = clamp_len - self.sample_softmax = sample_softmax - self.adaptive = adaptive - self.dropout = dropout - self.dropatt = dropatt - self.untie_r = untie_r - self.init = init - self.init_range = init_range - self.proj_init_std = proj_init_std - self.init_std = init_std - else: - raise ValueError("First argument must be either a vocabulary size (int)" - " or the path to a pretrained model config file (str)") - - @property - def max_position_embeddings(self): - return self.tgt_len + self.ext_len + self.mem_len - - @property - def vocab_size(self): - return self.n_token - - @vocab_size.setter - def vocab_size(self, value): - self.n_token = value - - @property - def hidden_size(self): - return self.d_model - - @property - def num_attention_heads(self): - return self.n_head - - @property - def num_hidden_layers(self): - return self.n_layer - - class PositionalEmbedding(nn.Module): def __init__(self, demb): super(PositionalEmbedding, self).__init__() @@ -347,7 +208,7 @@ class PositionwiseFF(nn.Module): nn.Dropout(dropout), ) - self.layer_norm = LayerNorm(d_model) + self.layer_norm = nn.LayerNorm(d_model) self.pre_lnorm = pre_lnorm @@ -387,7 +248,7 @@ class MultiHeadAttn(nn.Module): self.dropatt = nn.Dropout(dropatt) self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) - self.layer_norm = LayerNorm(d_model) + self.layer_norm = nn.LayerNorm(d_model) self.scale = 1 / (d_head ** 0.5) @@ -477,7 +338,7 @@ class RelMultiHeadAttn(nn.Module): self.dropatt = nn.Dropout(dropatt) self.o_net = nn.Linear(n_head * d_head, d_model, bias=False) - self.layer_norm = LayerNorm(d_model) + self.layer_norm = nn.LayerNorm(d_model) self.scale = 1 / (d_head ** 0.5) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index d482addb83..48420f6d07 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -30,11 +30,11 @@ from torch import nn from torch.nn import CrossEntropyLoss from torch.nn import functional as F +from .configuration_utils import PretrainedConfig from .file_utils import cached_path logger = logging.getLogger(__name__) -CONFIG_NAME = "config.json" WEIGHTS_NAME = "pytorch_model.bin" TF_WEIGHTS_NAME = 'model.ckpt' @@ -52,209 +52,6 @@ except ImportError: def forward(self, input): return input - -if not six.PY2: - def add_start_docstrings(*docstr): - def docstring_decorator(fn): - fn.__doc__ = ''.join(docstr) + fn.__doc__ - return fn - return docstring_decorator 
- - def add_end_docstrings(*docstr): - def docstring_decorator(fn): - fn.__doc__ = fn.__doc__ + ''.join(docstr) - return fn - return docstring_decorator -else: - # Not possible to update class docstrings on python2 - def add_start_docstrings(*docstr): - def docstring_decorator(fn): - return fn - return docstring_decorator - - def add_end_docstrings(*docstr): - def docstring_decorator(fn): - return fn - return docstring_decorator - - -class PretrainedConfig(object): - r""" Base class for all configuration classes. - Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. - - Note: - A configuration file can be loaded and saved to disk. Loading the configuration file and using this file to initialize a model does **not** load the model weights. - It only affects the model's configuration. - - Class attributes (overridden by derived classes): - - ``pretrained_config_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained model configurations as values. - - Parameters: - ``finetuning_task``: string, default `None`. Name of the task used to fine-tune the model. This can be used when converting from an original (TensorFlow or PyTorch) checkpoint. - ``num_labels``: integer, default `2`. Number of classes to use when the model is a classification model (sequences/tokens) - ``output_attentions``: boolean, default `False`. Should the model returns attentions weights. - ``output_hidden_states``: string, default `False`. Should the model returns all hidden-states. - ``torchscript``: string, default `False`. Is the model used with Torchscript. - """ - pretrained_config_archive_map = {} - - def __init__(self, **kwargs): - self.finetuning_task = kwargs.pop('finetuning_task', None) - self.num_labels = kwargs.pop('num_labels', 2) - self.output_attentions = kwargs.pop('output_attentions', False) - self.output_hidden_states = kwargs.pop('output_hidden_states', False) - self.torchscript = kwargs.pop('torchscript', False) - self.pruned_heads = kwargs.pop('pruned_heads', {}) - - def save_pretrained(self, save_directory): - """ Save a configuration object to the directory `save_directory`, so that it - can be re-loaded using the :func:`~pytorch_transformers.PretrainedConfig.from_pretrained` class method. - """ - assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved" - - # If we save using the predefined names, we can load using `from_pretrained` - output_config_file = os.path.join(save_directory, CONFIG_NAME) - - self.to_json_file(output_config_file) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): - r""" Instantiate a :class:`~pytorch_transformers.PretrainedConfig` (or a derived class) from a pre-trained model configuration. - - Parameters: - pretrained_model_name_or_path: either: - - - a string with the `shortcut name` of a pre-trained model configuration to load from cache or download, e.g.: ``bert-base-uncased``. - - a path to a `directory` containing a configuration file saved using the :func:`~pytorch_transformers.PretrainedConfig.save_pretrained` method, e.g.: ``./my_model_directory/``. - - a path or url to a saved configuration JSON `file`, e.g.: ``./my_model_directory/configuration.json``. - - cache_dir: (`optional`) string: - Path to a directory in which a downloaded pre-trained model - configuration should be cached if the standard cache should not be used. 
- - kwargs: (`optional`) dict: key/value pairs with which to update the configuration object after loading. - - - The values in kwargs of any keys which are configuration attributes will be used to override the loaded values. - - Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled by the `return_unused_kwargs` keyword parameter. - - force_download: (`optional`) boolean, default False: - Force to (re-)download the model weights and configuration files and override the cached versions if they exists. - - proxies: (`optional`) dict, default None: - A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}. - The proxies are used on each request. - - return_unused_kwargs: (`optional`) bool: - - - If False, then this function returns just the final configuration object. - - If True, then this functions returns a tuple `(config, unused_kwargs)` where `unused_kwargs` is a dictionary consisting of the key/value pairs whose keys are not configuration attributes: ie the part of kwargs which has not been used to update `config` and is otherwise ignored. - - Examples:: - - # We can't instantiate directly the base class `PretrainedConfig` so let's show the examples on a - # derived class: BertConfig - config = BertConfig.from_pretrained('bert-base-uncased') # Download configuration from S3 and cache. - config = BertConfig.from_pretrained('./test/saved_model/') # E.g. config (or model) was saved using `save_pretrained('./test/saved_model/')` - config = BertConfig.from_pretrained('./test/saved_model/my_configuration.json') - config = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, foo=False) - assert config.output_attention == True - config, unused_kwargs = BertConfig.from_pretrained('bert-base-uncased', output_attention=True, - foo=False, return_unused_kwargs=True) - assert config.output_attention == True - assert unused_kwargs == {'foo': False} - - """ - cache_dir = kwargs.pop('cache_dir', None) - force_download = kwargs.pop('force_download', False) - proxies = kwargs.pop('proxies', None) - return_unused_kwargs = kwargs.pop('return_unused_kwargs', False) - - if pretrained_model_name_or_path in cls.pretrained_config_archive_map: - config_file = cls.pretrained_config_archive_map[pretrained_model_name_or_path] - elif os.path.isdir(pretrained_model_name_or_path): - config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) - else: - config_file = pretrained_model_name_or_path - # redirect to the cache, if necessary - try: - resolved_config_file = cached_path(config_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies) - except EnvironmentError as e: - if pretrained_model_name_or_path in cls.pretrained_config_archive_map: - logger.error( - "Couldn't reach server at '{}' to download pretrained model configuration file.".format( - config_file)) - else: - logger.error( - "Model name '{}' was not found in model name list ({}). 
" - "We assumed '{}' was a path or url but couldn't find any file " - "associated to this path or url.".format( - pretrained_model_name_or_path, - ', '.join(cls.pretrained_config_archive_map.keys()), - config_file)) - raise e - if resolved_config_file == config_file: - logger.info("loading configuration file {}".format(config_file)) - else: - logger.info("loading configuration file {} from cache at {}".format( - config_file, resolved_config_file)) - - # Load config - config = cls.from_json_file(resolved_config_file) - - if hasattr(config, 'pruned_heads'): - config.pruned_heads = dict((int(key), set(value)) for key, value in config.pruned_heads.items()) - - # Update config with kwargs if needed - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - - logger.info("Model config %s", config) - if return_unused_kwargs: - return config, kwargs - else: - return config - - @classmethod - def from_dict(cls, json_object): - """Constructs a `Config` from a Python dictionary of parameters.""" - config = cls(vocab_size_or_config_json_file=-1) - for key, value in json_object.items(): - config.__dict__[key] = value - return config - - @classmethod - def from_json_file(cls, json_file): - """Constructs a `BertConfig` from a json file of parameters.""" - with open(json_file, "r", encoding='utf-8') as reader: - text = reader.read() - return cls.from_dict(json.loads(text)) - - def __eq__(self, other): - return self.__dict__ == other.__dict__ - - def __repr__(self): - return str(self.to_json_string()) - - def to_dict(self): - """Serializes this instance to a Python dictionary.""" - output = copy.deepcopy(self.__dict__) - return output - - def to_json_string(self): - """Serializes this instance to a JSON string.""" - return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" - - def to_json_file(self, json_file_path): - """ Save this instance to a json file.""" - with open(json_file_path, "w", encoding='utf-8') as writer: - writer.write(self.to_json_string()) - - class PreTrainedModel(nn.Module): r""" Base class for all models. 
diff --git a/pytorch_transformers/modeling_xlm.py b/pytorch_transformers/modeling_xlm.py index 7e13f10495..34406f5b61 100644 --- a/pytorch_transformers/modeling_xlm.py +++ b/pytorch_transformers/modeling_xlm.py @@ -16,11 +16,8 @@ """ from __future__ import absolute_import, division, print_function, unicode_literals -import json import logging import math -import sys -from io import open import itertools import numpy as np @@ -30,8 +27,9 @@ from torch import nn from torch.nn import functional as F from torch.nn import CrossEntropyLoss, MSELoss -from .modeling_utils import (PretrainedConfig, PreTrainedModel, add_start_docstrings, - prune_linear_layer, SequenceSummary, SQuADHead) +from .modeling_utils import PreTrainedModel, prune_linear_layer, SequenceSummary, SQuADHead +from .configuration_xlm import XLMConfig +from .file_utils import add_start_docstrings logger = logging.getLogger(__name__) @@ -47,164 +45,6 @@ XLM_PRETRAINED_MODEL_ARCHIVE_MAP = { 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-pytorch_model.bin", 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-pytorch_model.bin", } -XLM_PRETRAINED_CONFIG_ARCHIVE_MAP = { - 'xlm-mlm-en-2048': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-en-2048-config.json", - 'xlm-mlm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-ende-1024-config.json", - 'xlm-mlm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enfr-1024-config.json", - 'xlm-mlm-enro-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-enro-1024-config.json", - 'xlm-mlm-tlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-tlm-xnli15-1024-config.json", - 'xlm-mlm-xnli15-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-xnli15-1024-config.json", - 'xlm-clm-enfr-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-enfr-1024-config.json", - 'xlm-clm-ende-1024': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-clm-ende-1024-config.json", - 'xlm-mlm-17-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-17-1280-config.json", - 'xlm-mlm-100-1280': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-mlm-100-1280-config.json", -} - - -class XLMConfig(PretrainedConfig): - """Configuration class to store the configuration of a `XLMModel`. - - Args: - vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `XLMModel`. - d_model: Size of the encoder layers and the pooler layer. - n_layer: Number of hidden layers in the Transformer encoder. - n_head: Number of attention heads for each attention layer in - the Transformer encoder. - d_inner: The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder. - ff_activation: The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported. - untie_r: untie relative position biases - attn_type: 'bi' for XLM, 'uni' for Transformer-XL - - dropout: The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler. - dropatt: The dropout ratio for the attention - probabilities. - max_position_embeddings: The maximum sequence length that this model might - ever be used with. Typically set this to something large just in case - (e.g., 512 or 1024 or 2048). - initializer_range: The sttdev of the truncated_normal_initializer for - initializing all weight matrices. 
- layer_norm_eps: The epsilon used by LayerNorm. - - dropout: float, dropout rate. - dropatt: float, dropout rate on attention probabilities. - init: str, the initialization scheme, either "normal" or "uniform". - init_range: float, initialize the parameters with a uniform distribution - in [-init_range, init_range]. Only effective when init="uniform". - init_std: float, initialize the parameters with a normal distribution - with mean 0 and stddev init_std. Only effective when init="normal". - mem_len: int, the number of tokens to cache. - reuse_len: int, the number of tokens in the currect batch to be cached - and reused in the future. - bi_data: bool, whether to use bidirectional input pipeline. - Usually set to True during pretraining and False during finetuning. - clamp_len: int, clamp all relative distances larger than clamp_len. - -1 means no clamping. - same_length: bool, whether to use the same attention length for each token. - """ - pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP - - def __init__(self, - vocab_size_or_config_json_file=30145, - emb_dim=2048, - n_layers=12, - n_heads=16, - dropout=0.1, - attention_dropout=0.1, - gelu_activation=True, - sinusoidal_embeddings=False, - causal=False, - asm=False, - n_langs=1, - use_lang_emb=True, - max_position_embeddings=512, - embed_init_std=2048 ** -0.5, - layer_norm_eps=1e-12, - init_std=0.02, - bos_index=0, - eos_index=1, - pad_index=2, - unk_index=3, - mask_index=5, - is_encoder=True, - - finetuning_task=None, - num_labels=2, - summary_type='first', - summary_use_proj=True, - summary_activation=None, - summary_proj_to_labels=True, - summary_first_dropout=0.1, - start_n_top=5, - end_n_top=5, - **kwargs): - """Constructs XLMConfig. - """ - super(XLMConfig, self).__init__(**kwargs) - - if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 - and isinstance(vocab_size_or_config_json_file, unicode)): - with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - elif isinstance(vocab_size_or_config_json_file, int): - self.n_words = vocab_size_or_config_json_file - self.emb_dim = emb_dim - self.n_layers = n_layers - self.n_heads = n_heads - self.dropout = dropout - self.attention_dropout = attention_dropout - self.gelu_activation = gelu_activation - self.sinusoidal_embeddings = sinusoidal_embeddings - self.causal = causal - self.asm = asm - self.n_langs = n_langs - self.use_lang_emb = use_lang_emb - self.layer_norm_eps = layer_norm_eps - self.bos_index = bos_index - self.eos_index = eos_index - self.pad_index = pad_index - self.unk_index = unk_index - self.mask_index = mask_index - self.is_encoder = is_encoder - self.max_position_embeddings = max_position_embeddings - self.embed_init_std = embed_init_std - self.init_std = init_std - self.finetuning_task = finetuning_task - self.num_labels = num_labels - self.summary_type = summary_type - self.summary_use_proj = summary_use_proj - self.summary_activation = summary_activation - self.summary_proj_to_labels = summary_proj_to_labels - self.summary_first_dropout = summary_first_dropout - self.start_n_top = start_n_top - self.end_n_top = end_n_top - else: - raise ValueError("First argument must be either a vocabulary size (int)" - " or the path to a pretrained model config file (str)") - - @property - def vocab_size(self): - return self.n_words - - @vocab_size.setter - def vocab_size(self, value): - self.n_words = value - - 
@property - def hidden_size(self): - return self.emb_dim - - @property - def num_attention_heads(self): - return self.n_heads - - @property - def num_hidden_layers(self): - return self.n_layers def create_sinusoidal_embeddings(n_pos, dim, out): diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index 0dfb33a27a..00c15080a1 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -29,9 +29,9 @@ from torch import nn from torch.nn import functional as F from torch.nn import CrossEntropyLoss, MSELoss -from .modeling_utils import (CONFIG_NAME, WEIGHTS_NAME, PretrainedConfig, PreTrainedModel, - SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, - add_start_docstrings) +from .modeling_utils import PreTrainedModel, prune_linear_layer, SequenceSummary, PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits +from .configuration_xlnet import XLNetConfig +from .file_utils import add_start_docstrings logger = logging.getLogger(__name__) @@ -40,10 +40,6 @@ XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = { 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-pytorch_model.bin", 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-pytorch_model.bin", } -XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP = { - 'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-config.json", - 'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-config.json", -} def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None): @@ -192,147 +188,6 @@ def swish(x): ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish} -class XLNetConfig(PretrainedConfig): - """Configuration class to store the configuration of a ``XLNetModel``. - - Args: - vocab_size_or_config_json_file: Vocabulary size of ``inputs_ids`` in ``XLNetModel``. - d_model: Size of the encoder layers and the pooler layer. - n_layer: Number of hidden layers in the Transformer encoder. - n_head: Number of attention heads for each attention layer in - the Transformer encoder. - d_inner: The size of the "intermediate" (i.e., feed-forward) - layer in the Transformer encoder. - ff_activation: The non-linear activation function (function or string) in the - encoder and pooler. If string, "gelu", "relu" and "swish" are supported. - untie_r: untie relative position biases - attn_type: 'bi' for XLNet, 'uni' for Transformer-XL - - dropout: The dropout probabilitiy for all fully connected - layers in the embeddings, encoder, and pooler. - dropatt: The dropout ratio for the attention - probabilities. - initializer_range: The sttdev of the truncated_normal_initializer for - initializing all weight matrices. - layer_norm_eps: The epsilon used by LayerNorm. - - dropout: float, dropout rate. - dropatt: float, dropout rate on attention probabilities. - init: str, the initialization scheme, either "normal" or "uniform". - init_range: float, initialize the parameters with a uniform distribution - in [-init_range, init_range]. Only effective when init="uniform". - init_std: float, initialize the parameters with a normal distribution - with mean 0 and stddev init_std. Only effective when init="normal". - mem_len: int, the number of tokens to cache. - reuse_len: int, the number of tokens in the currect batch to be cached - and reused in the future. - bi_data: bool, whether to use bidirectional input pipeline. 
- Usually set to True during pretraining and False during finetuning. - clamp_len: int, clamp all relative distances larger than clamp_len. - -1 means no clamping. - same_length: bool, whether to use the same attention length for each token. - finetuning_task: name of the glue task on which the model was fine-tuned if any - """ - pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP - - def __init__(self, - vocab_size_or_config_json_file=32000, - d_model=1024, - n_layer=24, - n_head=16, - d_inner=4096, - ff_activation="gelu", - untie_r=True, - attn_type="bi", - - initializer_range=0.02, - layer_norm_eps=1e-12, - - dropout=0.1, - mem_len=None, - reuse_len=None, - bi_data=False, - clamp_len=-1, - same_length=False, - - finetuning_task=None, - num_labels=2, - summary_type='last', - summary_use_proj=True, - summary_activation='tanh', - summary_last_dropout=0.1, - start_n_top=5, - end_n_top=5, - **kwargs): - """Constructs XLNetConfig. - """ - super(XLNetConfig, self).__init__(**kwargs) - - if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2 - and isinstance(vocab_size_or_config_json_file, unicode)): - with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader: - json_config = json.loads(reader.read()) - for key, value in json_config.items(): - self.__dict__[key] = value - elif isinstance(vocab_size_or_config_json_file, int): - self.n_token = vocab_size_or_config_json_file - self.d_model = d_model - self.n_layer = n_layer - self.n_head = n_head - assert d_model % n_head == 0 - self.d_head = d_model // n_head - self.ff_activation = ff_activation - self.d_inner = d_inner - self.untie_r = untie_r - self.attn_type = attn_type - - self.initializer_range = initializer_range - self.layer_norm_eps = layer_norm_eps - - self.dropout = dropout - self.mem_len = mem_len - self.reuse_len = reuse_len - self.bi_data = bi_data - self.clamp_len = clamp_len - self.same_length = same_length - - self.finetuning_task = finetuning_task - self.num_labels = num_labels - self.summary_type = summary_type - self.summary_use_proj = summary_use_proj - self.summary_activation = summary_activation - self.summary_last_dropout = summary_last_dropout - self.start_n_top = start_n_top - self.end_n_top = end_n_top - else: - raise ValueError("First argument must be either a vocabulary size (int)" - " or the path to a pretrained model config file (str)") - - @property - def max_position_embeddings(self): - return -1 - - @property - def vocab_size(self): - return self.n_token - - @vocab_size.setter - def vocab_size(self, value): - self.n_token = value - - @property - def hidden_size(self): - return self.d_model - - @property - def num_attention_heads(self): - return self.n_head - - @property - def num_hidden_layers(self): - return self.n_layer - - try: from apex.normalization.fused_layer_norm import FusedLayerNorm as XLNetLayerNorm except (ImportError, AttributeError) as e: diff --git a/pytorch_transformers/tests/configuration_common_test.py b/pytorch_transformers/tests/configuration_common_test.py new file mode 100644 index 0000000000..8ee751153c --- /dev/null +++ b/pytorch_transformers/tests/configuration_common_test.py @@ -0,0 +1,63 @@ +# coding=utf-8 +# Copyright 2019 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import os +import shutil +import json +import random +import uuid + +import unittest +import logging + + +class ConfigTester(object): + def __init__(self, parent, config_class=None, **kwargs): + self.parent = parent + self.config_class = config_class + self.inputs_dict = kwargs + + def create_and_test_config_common_properties(self): + config = self.config_class(**self.inputs_dict) + self.parent.assertTrue(hasattr(config, 'vocab_size')) + self.parent.assertTrue(hasattr(config, 'hidden_size')) + self.parent.assertTrue(hasattr(config, 'num_attention_heads')) + self.parent.assertTrue(hasattr(config, 'num_hidden_layers')) + + def create_and_test_config_to_json_string(self): + config = self.config_class(**self.inputs_dict) + obj = json.loads(config.to_json_string()) + for key, value in self.inputs_dict.items(): + self.parent.assertEqual(obj[key], value) + + def create_and_test_config_to_json_file(self): + config_first = self.config_class(**self.inputs_dict) + json_file_path = os.path.join(os.getcwd(), "config_" + str(uuid.uuid4()) + ".json") + config_first.to_json_file(json_file_path) + config_second = self.config_class.from_json_file(json_file_path) + os.remove(json_file_path) + self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) + + def run_common_tests(self): + self.create_and_test_config_common_properties() + self.create_and_test_config_to_json_string() + self.create_and_test_config_to_json_file() + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/pytorch_transformers/tests/modeling_auto_test.py b/pytorch_transformers/tests/modeling_auto_test.py index 09d09b28fc..dfdedbbe61 100644 --- a/pytorch_transformers/tests/modeling_auto_test.py +++ b/pytorch_transformers/tests/modeling_auto_test.py @@ -28,7 +28,8 @@ from pytorch_transformers import (AutoConfig, BertConfig, AutoModelForQuestionAnswering, BertForQuestionAnswering) from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP -from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) +from .modeling_common_test import (CommonTestCases, ids_tensor) +from .configuration_common_test import ConfigTester class AutoModelTest(unittest.TestCase): diff --git a/pytorch_transformers/tests/modeling_bert_test.py b/pytorch_transformers/tests/modeling_bert_test.py index e1cce38479..2919cc0336 100644 --- a/pytorch_transformers/tests/modeling_bert_test.py +++ b/pytorch_transformers/tests/modeling_bert_test.py @@ -26,7 +26,8 @@ from pytorch_transformers import (BertConfig, BertModel, BertForMaskedLM, BertForTokenClassification, BertForMultipleChoice) from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP -from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) +from .modeling_common_test import (CommonTestCases, ids_tensor) +from .configuration_common_test import ConfigTester class BertModelTest(CommonTestCases.CommonModelTester): diff --git 
a/pytorch_transformers/tests/modeling_common_test.py b/pytorch_transformers/tests/modeling_common_test.py index 8b1a70fcf3..c50d6678d8 100644 --- a/pytorch_transformers/tests/modeling_common_test.py +++ b/pytorch_transformers/tests/modeling_common_test.py @@ -28,9 +28,9 @@ import logging import torch -from pytorch_transformers import PretrainedConfig, PreTrainedModel -from pytorch_transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP -from pytorch_transformers.modeling_gpt2 import GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP +from pytorch_transformers import (PretrainedConfig, PreTrainedModel, + BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, + GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) def _config_zero_init(config): diff --git a/pytorch_transformers/tests/modeling_distilbert_test.py b/pytorch_transformers/tests/modeling_distilbert_test.py index 7b0b6b8266..619dbf5850 100644 --- a/pytorch_transformers/tests/modeling_distilbert_test.py +++ b/pytorch_transformers/tests/modeling_distilbert_test.py @@ -18,13 +18,15 @@ from __future__ import print_function import unittest import shutil +import sys import pytest from pytorch_transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM, DistilBertForQuestionAnswering, DistilBertForSequenceClassification) from pytorch_transformers.modeling_distilbert import DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP -from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) +from .modeling_common_test import (CommonTestCases, ids_tensor) +from .configuration_common_test import ConfigTester class DistilBertModelTest(CommonTestCases.CommonModelTester): diff --git a/pytorch_transformers/tests/modeling_gpt2_test.py b/pytorch_transformers/tests/modeling_gpt2_test.py index 1786ada54c..2717805120 100644 --- a/pytorch_transformers/tests/modeling_gpt2_test.py +++ b/pytorch_transformers/tests/modeling_gpt2_test.py @@ -24,7 +24,8 @@ import shutil from pytorch_transformers import (GPT2Config, GPT2Model, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2DoubleHeadsModel) -from .modeling_common_test import CommonTestCases, ConfigTester, ids_tensor +from .modeling_common_test import (CommonTestCases, ids_tensor) +from .configuration_common_test import ConfigTester class GPT2ModelTest(CommonTestCases.CommonModelTester): diff --git a/pytorch_transformers/tests/modeling_openai_test.py b/pytorch_transformers/tests/modeling_openai_test.py index 0fcb4b7d64..dbef6c52eb 100644 --- a/pytorch_transformers/tests/modeling_openai_test.py +++ b/pytorch_transformers/tests/modeling_openai_test.py @@ -24,7 +24,8 @@ import shutil from pytorch_transformers import (OpenAIGPTConfig, OpenAIGPTModel, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel) -from .modeling_common_test import CommonTestCases, ConfigTester, ids_tensor +from .modeling_common_test import (CommonTestCases, ids_tensor) +from .configuration_common_test import ConfigTester class OpenAIGPTModelTest(CommonTestCases.CommonModelTester): diff --git a/pytorch_transformers/tests/modeling_roberta_test.py b/pytorch_transformers/tests/modeling_roberta_test.py index 0279f3756b..69981af222 100644 --- a/pytorch_transformers/tests/modeling_roberta_test.py +++ b/pytorch_transformers/tests/modeling_roberta_test.py @@ -24,7 +24,8 @@ import torch from pytorch_transformers import (RobertaConfig, RobertaModel, RobertaForMaskedLM, RobertaForSequenceClassification) from 
pytorch_transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP -from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) +from .modeling_common_test import (CommonTestCases, ids_tensor) +from .configuration_common_test import ConfigTester class RobertaModelTest(CommonTestCases.CommonModelTester): diff --git a/pytorch_transformers/tests/modeling_transfo_xl_test.py b/pytorch_transformers/tests/modeling_transfo_xl_test.py index e3c0fbcdf0..a060432cc4 100644 --- a/pytorch_transformers/tests/modeling_transfo_xl_test.py +++ b/pytorch_transformers/tests/modeling_transfo_xl_test.py @@ -28,7 +28,8 @@ import torch from pytorch_transformers import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel) from pytorch_transformers.modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP -from .modeling_common_test import ConfigTester, CommonTestCases, ids_tensor +from .modeling_common_test import (CommonTestCases, ids_tensor) +from .configuration_common_test import ConfigTester class TransfoXLModelTest(CommonTestCases.CommonModelTester): diff --git a/pytorch_transformers/tests/modeling_xlm_test.py b/pytorch_transformers/tests/modeling_xlm_test.py index 4308c18d45..dcd0963477 100644 --- a/pytorch_transformers/tests/modeling_xlm_test.py +++ b/pytorch_transformers/tests/modeling_xlm_test.py @@ -23,7 +23,8 @@ import pytest from pytorch_transformers import (XLMConfig, XLMModel, XLMWithLMHeadModel, XLMForQuestionAnswering, XLMForSequenceClassification) from pytorch_transformers.modeling_xlm import XLM_PRETRAINED_MODEL_ARCHIVE_MAP -from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) +from .modeling_common_test import (CommonTestCases, ids_tensor) +from .configuration_common_test import ConfigTester class XLMModelTest(CommonTestCases.CommonModelTester): diff --git a/pytorch_transformers/tests/modeling_xlnet_test.py b/pytorch_transformers/tests/modeling_xlnet_test.py index 290c5766e2..4445bc17ac 100644 --- a/pytorch_transformers/tests/modeling_xlnet_test.py +++ b/pytorch_transformers/tests/modeling_xlnet_test.py @@ -28,7 +28,8 @@ import torch from pytorch_transformers import (XLNetConfig, XLNetModel, XLNetLMHeadModel, XLNetForSequenceClassification, XLNetForQuestionAnswering) from pytorch_transformers.modeling_xlnet import XLNET_PRETRAINED_MODEL_ARCHIVE_MAP -from .modeling_common_test import ConfigTester, CommonTestCases, ids_tensor +from .modeling_common_test import (CommonTestCases, ids_tensor) +from .configuration_common_test import ConfigTester class XLNetModelTest(CommonTestCases.CommonModelTester):
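All of the test files above now import ConfigTester from the new configuration_common_test.py module rather than from modeling_common_test.py. As a minimal sketch of how a test case drives it, assuming the tests are importable as a package (the absolute import path and the d_inner=37 value are illustrative assumptions, not taken from the diff):

import unittest

from pytorch_transformers import XLNetConfig
from pytorch_transformers.tests.configuration_common_test import ConfigTester


class XLNetConfigTest(unittest.TestCase):
    def setUp(self):
        # ConfigTester keeps a reference to the test case as "parent" so its
        # helpers can call self.parent.assertTrue(...) / assertEqual(...).
        self.config_tester = ConfigTester(self, config_class=XLNetConfig, d_inner=37)

    def test_config(self):
        # Runs the three shared checks: common properties (vocab_size,
        # hidden_size, num_attention_heads, num_hidden_layers), JSON-string
        # round-trip of the constructor kwargs, and JSON-file round-trip.
        self.config_tester.run_common_tests()


if __name__ == "__main__":
    unittest.main()

Note that the kwargs passed to ConfigTester must be stored verbatim on the config (d_inner is, whereas vocab_size_or_config_json_file is renamed internally), because create_and_test_config_to_json_string compares them against the serialized attributes by key.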