diff --git a/hubconf.py b/hubconf.py
index ba09cbab3c..f833620780 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -16,4 +16,15 @@ from hubconfs.gpt_hubconf import (
     openAIGPTModel,
     openAIGPTLMHeadModel,
     openAIGPTDoubleHeadsModel
-)
\ No newline at end of file
+)
+from hubconfs.gpt2_hubconf import (
+    gpt2Tokenizer,
+    gpt2Model,
+    gpt2LMHeadModel,
+    gpt2DoubleHeadsModel
+)
+from hubconfs.transformer_xl_hubconf import (
+    transformerXLTokenizer,
+    transformerXLModel,
+    transformerXLLMHeadModel
+)
diff --git a/hubconfs/gpt2_hubconf.py b/hubconfs/gpt2_hubconf.py
new file mode 100644
index 0000000000..26b53e8b03
--- /dev/null
+++ b/hubconfs/gpt2_hubconf.py
@@ -0,0 +1,164 @@
+from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer
+from pytorch_pretrained_bert.modeling_gpt2 import (
+    GPT2Model,
+    GPT2LMHeadModel,
+    GPT2DoubleHeadsModel
+)
+
+# A lot of models share the same param doc. Use a decorator
+# to save typing
+gpt2_docstring = """
+    Params:
+        pretrained_model_name_or_path: either:
+            - a str with the name of a pre-trained model to load selected in the list of:
+                . `gpt2`
+            - a path or url to a pretrained model archive containing:
+                . `gpt2_config.json` a configuration file for the model
+                . `pytorch_model.bin` a PyTorch dump of a GPT2Model instance
+            - a path or url to a pretrained model archive containing:
+                . `gpt2_config.json` a configuration file for the model
+                . a TensorFlow checkpoint with trained weights
+        from_tf: should we load the weights from a locally saved TensorFlow checkpoint
+        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
+        state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
+        *inputs, **kwargs: additional input for the specific GPT-2 class
+"""
+
+
+def _append_from_pretrained_docstring(docstr):
+    def docstring_decorator(fn):
+        fn.__doc__ = fn.__doc__ + docstr
+        return fn
+    return docstring_decorator
+
+
+def gpt2Tokenizer(*args, **kwargs):
+    """
+    Instantiate a GPT-2 BPE tokenizer for OpenAI GPT-2 from a pre-trained/customized vocab file.
+    Peculiarities:
+        - Byte-level BPE
+
+    Args:
+    pretrained_model_name_or_path: Path to pretrained model archive
+                                   or one of pre-trained vocab configs below.
+                                       * gpt2
+    Keyword args:
+    special_tokens: Special tokens in vocabulary that are not pretrained ([SEP], [CLS]...)
+                    Default: None
+    max_len: An artificial maximum length to truncate tokenized sequences to;
+             Effective maximum length is always the minimum of this
+             value (if specified) and the underlying GPT-2 model's
+             sequence length.
+             Default: None
+
+    Example:
+        >>> import torch
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Tokenizer', 'gpt2')
+
+        >>> text = "Who was Jim Henson ?"
+        >>> indexed_tokens = tokenizer.encode(text)
+    """
+    tokenizer = GPT2Tokenizer.from_pretrained(*args, **kwargs)
+    return tokenizer
+
+
+@_append_from_pretrained_docstring(gpt2_docstring)
+def gpt2Model(*args, **kwargs):
+    """
+    gpt2Model is the basic OpenAI GPT-2 Transformer model, based on
+    identical stacked masked self-attention blocks and pre-trained
+    on a large-scale dataset using a language modeling signal.
+
+    Example:
+        # Load the tokenizer
+        >>> import torch
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Tokenizer', 'gpt2')
+
+        # Prepare tokenized input
+        >>> text_1 = "Who was Jim Henson ?"
+        >>> text_2 = "Jim Henson was a puppeteer"
+        >>> indexed_tokens_1 = tokenizer.encode(text_1)
+        >>> indexed_tokens_2 = tokenizer.encode(text_2)
+        >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
+        >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])
+
+        # Load gpt2Model
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Model', 'gpt2')
+        >>> model.eval()
+
+        # Predict hidden states features for each layer
+        # past can be used to reuse precomputed hidden states in subsequent predictions
+        >>> with torch.no_grad():
+                hidden_states_1, past = model(tokens_tensor_1)
+                hidden_states_2, past = model(tokens_tensor_2, past=past)
+    """
+    model = GPT2Model.from_pretrained(*args, **kwargs)
+    return model
+
+
+@_append_from_pretrained_docstring(gpt2_docstring)
+def gpt2LMHeadModel(*args, **kwargs):
+    """
+    gpt2LMHeadModel is the OpenAI GPT-2 Transformer model with the
+    tied (pre-trained) language modeling head on top.
+
+    Example:
+        # Load the tokenizer
+        >>> import torch
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Tokenizer', 'gpt2')
+
+        # Prepare tokenized input
+        >>> text_1 = "Who was Jim Henson ?"
+        >>> text_2 = "Jim Henson was a puppeteer"
+        >>> indexed_tokens_1 = tokenizer.encode(text_1)
+        >>> indexed_tokens_2 = tokenizer.encode(text_2)
+        >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
+        >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])
+
+        # Load gpt2LMHeadModel
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2LMHeadModel', 'gpt2')
+        >>> model.eval()
+
+        # Predict hidden states features for each layer
+        # past can be used to reuse precomputed hidden states in subsequent predictions
+        >>> with torch.no_grad():
+                predictions_1, past = model(tokens_tensor_1)
+                predictions_2, past = model(tokens_tensor_2, past=past)
+
+        # Get the predicted last token
+        >>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
+        >>> predicted_token = tokenizer.decode([predicted_index])
+        >>> assert predicted_token == ' who'
+    """
+    model = GPT2LMHeadModel.from_pretrained(*args, **kwargs)
+    return model
+
+
+@_append_from_pretrained_docstring(gpt2_docstring)
+def gpt2DoubleHeadsModel(*args, **kwargs):
+    """
+    gpt2DoubleHeadsModel is the OpenAI GPT-2 Transformer model with the
+    tied (pre-trained) language modeling head and a multiple choice
+    classification head (only initialized, not pre-trained).
+
+    Example:
+        # Load the tokenizer
+        >>> import torch
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Tokenizer', 'gpt2')
+
+        # Prepare tokenized input
+        >>> text = "Who was Jim Henson ?"
+        >>> indexed_tokens = tokenizer.encode(text)
+        >>> tokens_tensor = torch.tensor([indexed_tokens])
+        >>> mc_token_ids = torch.LongTensor([ [len(indexed_tokens)] ])
+
+        # Load gpt2DoubleHeadsModel
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2DoubleHeadsModel', 'gpt2')
+        >>> model.eval()
+
+        # Predict hidden states features for each layer
+        >>> with torch.no_grad():
+                lm_logits, multiple_choice_logits, presents = model(tokens_tensor, mc_token_ids)
+    """
+    model = GPT2DoubleHeadsModel.from_pretrained(*args, **kwargs)
+    return model
diff --git a/hubconfs/transformer_xl_hubconf.py b/hubconfs/transformer_xl_hubconf.py
new file mode 100644
index 0000000000..d5c697547e
--- /dev/null
+++ b/hubconfs/transformer_xl_hubconf.py
@@ -0,0 +1,130 @@
+from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer
+from pytorch_pretrained_bert.modeling_transfo_xl import (
+    TransfoXLModel,
+    TransfoXLLMHeadModel
+)
+
+# A lot of models share the same param doc. Use a decorator
+# to save typing
+transformer_xl_docstring = """
+    Transformer-XL uses relative positioning (with sinusoidal patterns) and adaptive softmax inputs, which means that:
+    - you don't need to specify positional embedding indices
+    - the tokens in the vocabulary have to be sorted in decreasing frequency.
+
+    Params:
+        pretrained_model_name_or_path: either:
+            - a str with the name of a pre-trained model to load selected in the list of:
+                . `transfo-xl-wt103`
+            - a path or url to a pretrained model archive containing:
+                . `transfo_xl_config.json` a configuration file for the model
+                . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
+            - a path or url to a pretrained model archive containing:
+                . `transfo_xl_config.json` a configuration file for the model
+                . `model.chkpt` a TensorFlow checkpoint
+        from_tf: should we load the weights from a locally saved TensorFlow checkpoint
+        cache_dir: an optional path to a folder in which the pre-trained models will be cached.
+        state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
+        *inputs, **kwargs: additional input for the specific TransformerXL class
+"""
+
+
+def _append_from_pretrained_docstring(docstr):
+    def docstring_decorator(fn):
+        fn.__doc__ = fn.__doc__ + docstr
+        return fn
+    return docstring_decorator
+
+
+def transformerXLTokenizer(*args, **kwargs):
+    """
+    Instantiate a Transformer-XL tokenizer adapted from the Vocab class in https://github.com/kimiyoung/transformer-xl
+
+    Args:
+    pretrained_model_name_or_path: Path to pretrained model archive
+                                   or one of pre-trained vocab configs below.
+                                       * transfo-xl-wt103
+
+    Example:
+        >>> import torch
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLTokenizer', 'transfo-xl-wt103')
+
+        >>> text = "Who was Jim Henson ?"
+        >>> tokenized_text = tokenizer.tokenize(text)
+        >>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
+    """
+    tokenizer = TransfoXLTokenizer.from_pretrained(*args, **kwargs)
+    return tokenizer
+
+
+@_append_from_pretrained_docstring(transformer_xl_docstring)
+def transformerXLModel(*args, **kwargs):
+    """
+    transformerXLModel is the basic Transformer-XL model.
+
+    Example:
+        # Load the tokenizer
+        >>> import torch
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLTokenizer', 'transfo-xl-wt103')
+
+        # Prepare tokenized input
+        >>> text_1 = "Who was Jim Henson ?"
+        >>> text_2 = "Jim Henson was a puppeteer"
+        >>> tokenized_text_1 = tokenizer.tokenize(text_1)
+        >>> tokenized_text_2 = tokenizer.tokenize(text_2)
+        >>> indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
+        >>> indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
+        >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
+        >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])
+
+        # Load transformerXLModel
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLModel', 'transfo-xl-wt103')
+        >>> model.eval()
+
+        # Predict hidden states features for each layer
+        # We can re-use the memory cells in a subsequent call to attend to a longer context
+        >>> with torch.no_grad():
+                hidden_states_1, mems_1 = model(tokens_tensor_1)
+                hidden_states_2, mems_2 = model(tokens_tensor_2, mems=mems_1)
+    """
+    model = TransfoXLModel.from_pretrained(*args, **kwargs)
+    return model
+
+
+@_append_from_pretrained_docstring(transformer_xl_docstring)
+def transformerXLLMHeadModel(*args, **kwargs):
+    """
+    transformerXLLMHeadModel is the Transformer-XL model with the
+    tied (pre-trained) language modeling head on top.
+
+    Example:
+        # Load the tokenizer
+        >>> import torch
+        >>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLTokenizer', 'transfo-xl-wt103')
+
+        # Prepare tokenized input
+        >>> text_1 = "Who was Jim Henson ?"
+        >>> text_2 = "Jim Henson was a puppeteer"
+        >>> tokenized_text_1 = tokenizer.tokenize(text_1)
+        >>> tokenized_text_2 = tokenizer.tokenize(text_2)
+        >>> indexed_tokens_1 = tokenizer.convert_tokens_to_ids(tokenized_text_1)
+        >>> indexed_tokens_2 = tokenizer.convert_tokens_to_ids(tokenized_text_2)
+        >>> tokens_tensor_1 = torch.tensor([indexed_tokens_1])
+        >>> tokens_tensor_2 = torch.tensor([indexed_tokens_2])
+
+        # Load transformerXLLMHeadModel
+        >>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'transformerXLLMHeadModel', 'transfo-xl-wt103')
+        >>> model.eval()
+
+        # Predict hidden states features for each layer
+        # We can re-use the memory cells in a subsequent call to attend to a longer context
+        >>> with torch.no_grad():
+                predictions_1, mems_1 = model(tokens_tensor_1)
+                predictions_2, mems_2 = model(tokens_tensor_2, mems=mems_1)
+
+        # Get the predicted last token
+        >>> predicted_index = torch.argmax(predictions_2[0, -1, :]).item()
+        >>> predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
+        >>> assert predicted_token == 'who'
+    """
+    model = TransfoXLLMHeadModel.from_pretrained(*args, **kwargs)
+    return model
diff --git a/pytorch_pretrained_bert/modeling_gpt2.py b/pytorch_pretrained_bert/modeling_gpt2.py
index d462fe04ef..396364d549 100644
--- a/pytorch_pretrained_bert/modeling_gpt2.py
+++ b/pytorch_pretrained_bert/modeling_gpt2.py
@@ -406,9 +406,7 @@ class GPT2PreTrainedModel(nn.Module):
             module.bias.data.zero_()
 
     @classmethod
-    def from_pretrained(
-        cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
-    ):
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         """
         Instantiate a GPT2PreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
@@ -426,8 +424,15 @@ class GPT2PreTrainedModel(nn.Module):
             from_tf: should we load the weights from a locally saved TensorFlow checkpoint
             cache_dir: an optional path to a folder in which the pre-trained models will be cached.
             state_dict: an optional state dictionary (collections.OrderedDict object) to use instead of pre-trained models
-            *inputs, **kwargs: additional input for the specific GPT class
+            *inputs, **kwargs: additional input for the specific GPT2 class
         """
+        state_dict = kwargs.get('state_dict', None)
+        kwargs.pop('state_dict', None)
+        cache_dir = kwargs.get('cache_dir', None)
+        kwargs.pop('cache_dir', None)
+        from_tf = kwargs.get('from_tf', False)
+        kwargs.pop('from_tf', None)
+
         if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
             archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
             config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
@@ -770,7 +775,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
 
     config = modeling_gpt2.GPT2Config()
 
-    model = modeling_gpt2.GPT2LMHeadModel(config)
+    model = modeling_gpt2.GPT2DoubleHeadsModel(config)
     lm_logits, multiple_choice_logits, presents = model(input_ids, mc_token_ids)
     ```
     """
diff --git a/pytorch_pretrained_bert/modeling_openai.py b/pytorch_pretrained_bert/modeling_openai.py
index f805f63912..2b44803584 100644
--- a/pytorch_pretrained_bert/modeling_openai.py
+++ b/pytorch_pretrained_bert/modeling_openai.py
@@ -815,7 +815,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
 
     config = modeling_openai.OpenAIGPTConfig()
 
-    model = modeling_openai.OpenAIGPTLMHeadModel(config)
+    model = modeling_openai.OpenAIGPTDoubleHeadsModel(config)
     lm_logits, multiple_choice_logits = model(input_ids, mc_token_ids)
     ```
     """
diff --git a/pytorch_pretrained_bert/modeling_transfo_xl.py b/pytorch_pretrained_bert/modeling_transfo_xl.py
index e8fffc5b60..e70a29af57 100644
--- a/pytorch_pretrained_bert/modeling_transfo_xl.py
+++ b/pytorch_pretrained_bert/modeling_transfo_xl.py
@@ -888,8 +888,7 @@ class TransfoXLPreTrainedModel(nn.Module):
         pass
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, state_dict=None, cache_dir=None,
-                        from_tf=False, *inputs, **kwargs):
+    def from_pretrained(cls, pretrained_model_name_or_path, *inputs, **kwargs):
         """
         Instantiate a TransfoXLPreTrainedModel from a pre-trained model file or a pytorch state dict.
         Download and cache the pre-trained model file if needed.
@@ -897,19 +896,25 @@ class TransfoXLPreTrainedModel(nn.Module):
         Params:
             pretrained_model_name_or_path: either:
                 - a str with the name of a pre-trained model to load selected in the list of:
-                    . `transfo-xl`
+                    . `transfo-xl-wt103`
                 - a path or url to a pretrained model archive containing:
                     . `transfo_xl_config.json` a configuration file for the model
                     . `pytorch_model.bin` a PyTorch dump of a TransfoXLModel instance
                 - a path or url to a pretrained model archive containing:
-                    . `bert_config.json` a configuration file for the model
+                    . `transfo_xl_config.json` a configuration file for the model
                     . `model.chkpt` a TensorFlow checkpoint
             from_tf: should we load the weights from a locally saved TensorFlow checkpoint
             cache_dir: an optional path to a folder in which the pre-trained models will be cached.
             state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
-            *inputs, **kwargs: additional input for the specific Bert class
-                (ex: num_labels for BertForSequenceClassification)
+            *inputs, **kwargs: additional input for the specific TransformerXL class
         """
+        state_dict = kwargs.get('state_dict', None)
+        kwargs.pop('state_dict', None)
+        cache_dir = kwargs.get('cache_dir', None)
+        kwargs.pop('cache_dir', None)
+        from_tf = kwargs.get('from_tf', False)
+        kwargs.pop('from_tf', None)
+
         if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
             archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
             config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
diff --git a/pytorch_pretrained_bert/tokenization_gpt2.py b/pytorch_pretrained_bert/tokenization_gpt2.py
index c66af3ff13..af75cac4dc 100644
--- a/pytorch_pretrained_bert/tokenization_gpt2.py
+++ b/pytorch_pretrained_bert/tokenization_gpt2.py
@@ -93,7 +93,7 @@ class GPT2Tokenizer(object):
     @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
-        Instantiate a PreTrainedBertModel from a pre-trained model file.
+        Instantiate a GPT2Tokenizer from a pre-trained model file.
         Download and cache the pre-trained model file if needed.
         """
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
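
The new entrypoints can be sanity-checked end to end by combining the docstring examples above into one script. The following is a minimal usage sketch assembled from those examples, assuming network access so torch.hub can fetch this repo and the `gpt2` weights; it is illustrative only and not part of the patch.

import torch

# Load the GPT-2 tokenizer and LM-head model through the hub entrypoints added in this diff
tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2Tokenizer', 'gpt2')
model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'gpt2LMHeadModel', 'gpt2')
model.eval()

# Encode a prompt and predict the most likely next token
indexed_tokens = tokenizer.encode("Who was Jim Henson ? Jim Henson was a")
tokens_tensor = torch.tensor([indexed_tokens])
with torch.no_grad():
    predictions, past = model(tokens_tensor)
predicted_index = torch.argmax(predictions[0, -1, :]).item()
print(tokenizer.decode([predicted_index]))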