Merge pull request #651 from huggingface/gpt_torchhub

Add GPT* compatibility to torchhub
This commit is contained in:
Thomas Wolf 2019-05-31 14:44:52 +02:00 committed by GitHub
commit 2a329c6186
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 456 additions and 252 deletions

View File

@ -1,248 +1,19 @@
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import (
BertModel,
BertForNextSentencePrediction,
BertForMaskedLM,
BertForMultipleChoice,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
)
dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']
# A lot of models share the same param doc. Use a decorator
# to save typing
bert_docstring = """
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-large-cased`
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining
instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow
checkpoint
cache_dir: an optional path to a folder in which the pre-trained models
will be cached.
state_dict: an optional state dictionnary
(collections.OrderedDict object) to use instead of Google
pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
def _append_from_pretrained_docstring(docstr):
def docstring_decorator(fn):
fn.__doc__ = fn.__doc__ + docstr
return fn
return docstring_decorator
def bertTokenizer(*args, **kwargs):
"""
Instantiate a BertTokenizer from a pre-trained/customized vocab file
Args:
pretrained_model_name_or_path: Path to pretrained model archive
or one of pre-trained vocab configs below.
* bert-base-uncased
* bert-large-uncased
* bert-base-cased
* bert-large-cased
* bert-base-multilingual-uncased
* bert-base-multilingual-cased
* bert-base-chinese
Keyword args:
cache_dir: an optional path to a specific directory to download and cache
the pre-trained model weights.
Default: None
do_lower_case: Whether to lower case the input.
Only has an effect when do_wordpiece_only=False
Default: True
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
Default: True
max_len: An artificial maximum length to truncate tokenized sequences to;
Effective maximum length is always the minimum of this
value (if specified) and the underlying BERT model's
sequence length.
Default: None
never_split: List of tokens which will never be split during tokenization.
Only has an effect when do_wordpiece_only=False
Default: ["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]
Example:
>>> sentence = 'Hello, World!'
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
>>> toks = tokenizer.tokenize(sentence)
['Hello', '##,', 'World', '##!']
>>> ids = tokenizer.convert_tokens_to_ids(toks)
[8667, 28136, 1291, 28125]
"""
tokenizer = BertTokenizer.from_pretrained(*args, **kwargs)
return tokenizer
@_append_from_pretrained_docstring(bert_docstring)
def bertModel(*args, **kwargs):
"""
BertModel is the basic BERT Transformer model with a layer of summed token,
position and sequence embeddings followed by a series of identical
self-attention blocks (12 for BERT-base, 24 for BERT-large).
Example:
# Load the tokenizer
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
# Prepare tokenized input
>>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
>>> tokenized_text = tokenizer.tokenize(text)
['[CLS]', 'Who', 'was', 'Jim', 'He', '##nson', '?', '[SEP]', 'Jim', 'He', '##nson', 'was', 'a', 'puppet', '##eer', '[SEP]']
>>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
>>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
>>> tokens_tensor = torch.tensor([indexed_tokens])
tensor([[101, 2627, 1108, 3104, 1124, 15703, 136, 102, 3104, 1124, 15703, 1108, 170, 16797, 8284, 102]])
>>> segments_tensors = torch.tensor([segments_ids])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
# Load bertModel
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertModel', 'bert-base-cased', force_reload=False)
>>> model.eval()
# Predict hidden states features for each layer
>>> with torch.no_grad():
encoded_layers, _ = model(tokens_tensor, segments_tensors)
"""
model = BertModel.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForNextSentencePrediction(*args, **kwargs):
"""
BERT model with next sentence prediction head.
This module comprises the BERT model followed by the next sentence
classification head.
"""
model = BertForNextSentencePrediction.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForPreTraining(*args, **kwargs):
"""
BERT model with pre-training heads.
This module comprises the BERT model followed by the two pre-training heads
- the masked language modeling head, and
- the next sentence classification head.
"""
model = BertForPreTraining.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForMaskedLM(*args, **kwargs):
"""
BertForMaskedLM includes the BertModel Transformer followed by the
(possibly) pre-trained masked language modeling head.
Example:
# Load the tokenizer
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
# Prepare tokenized input
>>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
>>> tokenized_text = tokenizer.tokenize(text)
>>> masked_index = 8
>>> tokenized_text[masked_index] = '[MASK]'
['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']
>>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
>>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
>>> tokens_tensor = torch.tensor([indexed_tokens])
>>> segments_tensors = torch.tensor([segments_ids])
# Load bertForMaskedLM
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForMaskedLM', 'bert-base-cased', force_reload=False)
>>> model.eval()
# Predict all tokens
>>> with torch.no_grad():
predictions = model(tokens_tensor, segments_tensors)
>>> predicted_index = torch.argmax(predictions[0, masked_index]).item()
>>> predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
'henson'
"""
model = BertForMaskedLM.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForSequenceClassification(*args, **kwargs):
"""
BertForSequenceClassification is a fine-tuning model that includes
BertModel and a sequence-level (sequence or pair of sequences) classifier
on top of the BertModel.
The sequence-level classifier is a linear layer that takes as input the
last hidden state of the first character in the input sequence
(see Figures 3a and 3b in the BERT paper).
Args:
num_labels: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForSequenceClassification', 'bert-base-cased', num_labels=2, force_reload=True)
"""
model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForMultipleChoice(*args, **kwargs):
"""
BertForMultipleChoice is a fine-tuning model that includes BertModel and a
linear layer on top of the BertModel.
Args:
num_choices: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForMultipleChoice', 'bert-base-cased', num_choices=2, force_reload=True)
"""
model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForQuestionAnswering(*args, **kwargs):
"""
BertForQuestionAnswering is a fine-tuning model that includes BertModel
with a token-level classifiers on top of the full sequence of last hidden
states.
"""
model = BertForQuestionAnswering.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForTokenClassification(*args, **kwargs):
"""
BertForTokenClassification is a fine-tuning model that includes BertModel
and a token-level classifier on top of the BertModel.
The token-level classifier is a linear layer that takes as input the last
hidden state of the sequence.
Args:
num_labels: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForTokenClassification', 'bert-base-cased', num_labels=2, force_reload=True)
"""
model = BertForTokenClassification.from_pretrained(*args, **kwargs)
return model
from hubconfs.bert_hubconf import (
bertTokenizer,
bertModel,
bertForNextSentencePrediction,
bertForPreTraining,
bertForMaskedLM,
bertForSequenceClassification,
bertForMultipleChoice,
bertForQuestionAnswering,
bertForTokenClassification
)
from hubconfs.gpt_hubconf import (
openAIGPTTokenizer,
openAIGPTModel,
openAIGPTLMHeadModel,
openAIGPTDoubleHeadsModel
)

246
hubconfs/bert_hubconf.py Normal file
View File

@ -0,0 +1,246 @@
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import (
BertModel,
BertForNextSentencePrediction,
BertForMaskedLM,
BertForMultipleChoice,
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BertForTokenClassification,
)
# A lot of models share the same param doc. Use a decorator
# to save typing
bert_docstring = """
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load
. `bert-base-uncased`
. `bert-large-uncased`
. `bert-base-cased`
. `bert-large-cased`
. `bert-base-multilingual-uncased`
. `bert-base-multilingual-cased`
. `bert-base-chinese`
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining
instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
from_tf: should we load the weights from a locally saved TensorFlow
checkpoint
cache_dir: an optional path to a folder in which the pre-trained models
will be cached.
state_dict: an optional state dictionnary
(collections.OrderedDict object) to use instead of Google
pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
"""
def _append_from_pretrained_docstring(docstr):
def docstring_decorator(fn):
fn.__doc__ = fn.__doc__ + docstr
return fn
return docstring_decorator
def bertTokenizer(*args, **kwargs):
"""
Instantiate a BertTokenizer from a pre-trained/customized vocab file
Args:
pretrained_model_name_or_path: Path to pretrained model archive
or one of pre-trained vocab configs below.
* bert-base-uncased
* bert-large-uncased
* bert-base-cased
* bert-large-cased
* bert-base-multilingual-uncased
* bert-base-multilingual-cased
* bert-base-chinese
Keyword args:
cache_dir: an optional path to a specific directory to download and cache
the pre-trained model weights.
Default: None
do_lower_case: Whether to lower case the input.
Only has an effect when do_wordpiece_only=False
Default: True
do_basic_tokenize: Whether to do basic tokenization before wordpiece.
Default: True
max_len: An artificial maximum length to truncate tokenized sequences to;
Effective maximum length is always the minimum of this
value (if specified) and the underlying BERT model's
sequence length.
Default: None
never_split: List of tokens which will never be split during tokenization.
Only has an effect when do_wordpiece_only=False
Default: ["[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"]
Example:
>>> sentence = 'Hello, World!'
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
>>> toks = tokenizer.tokenize(sentence)
['Hello', '##,', 'World', '##!']
>>> ids = tokenizer.convert_tokens_to_ids(toks)
[8667, 28136, 1291, 28125]
"""
tokenizer = BertTokenizer.from_pretrained(*args, **kwargs)
return tokenizer
@_append_from_pretrained_docstring(bert_docstring)
def bertModel(*args, **kwargs):
"""
BertModel is the basic BERT Transformer model with a layer of summed token,
position and sequence embeddings followed by a series of identical
self-attention blocks (12 for BERT-base, 24 for BERT-large).
Example:
# Load the tokenizer
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
# Prepare tokenized input
>>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
>>> tokenized_text = tokenizer.tokenize(text)
['[CLS]', 'Who', 'was', 'Jim', 'He', '##nson', '?', '[SEP]', 'Jim', 'He', '##nson', 'was', 'a', 'puppet', '##eer', '[SEP]']
>>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
>>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
>>> tokens_tensor = torch.tensor([indexed_tokens])
tensor([[101, 2627, 1108, 3104, 1124, 15703, 136, 102, 3104, 1124, 15703, 1108, 170, 16797, 8284, 102]])
>>> segments_tensors = torch.tensor([segments_ids])
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]])
# Load bertModel
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertModel', 'bert-base-cased', force_reload=False)
>>> model.eval()
# Predict hidden states features for each layer
>>> with torch.no_grad():
encoded_layers, _ = model(tokens_tensor, segments_tensors)
"""
model = BertModel.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForNextSentencePrediction(*args, **kwargs):
"""
BERT model with next sentence prediction head.
This module comprises the BERT model followed by the next sentence
classification head.
"""
model = BertForNextSentencePrediction.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForPreTraining(*args, **kwargs):
"""
BERT model with pre-training heads.
This module comprises the BERT model followed by the two pre-training heads
- the masked language modeling head, and
- the next sentence classification head.
"""
model = BertForPreTraining.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForMaskedLM(*args, **kwargs):
"""
BertForMaskedLM includes the BertModel Transformer followed by the
(possibly) pre-trained masked language modeling head.
Example:
# Load the tokenizer
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertTokenizer', 'bert-base-cased', do_basic_tokenize=False, force_reload=False)
# Prepare tokenized input
>>> text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
>>> tokenized_text = tokenizer.tokenize(text)
>>> masked_index = 8
>>> tokenized_text[masked_index] = '[MASK]'
['[CLS]', 'who', 'was', 'jim', 'henson', '?', '[SEP]', 'jim', '[MASK]', 'was', 'a', 'puppet', '##eer', '[SEP]']
>>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
>>> segments_ids = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
>>> tokens_tensor = torch.tensor([indexed_tokens])
>>> segments_tensors = torch.tensor([segments_ids])
# Load bertForMaskedLM
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForMaskedLM', 'bert-base-cased', force_reload=False)
>>> model.eval()
# Predict all tokens
>>> with torch.no_grad():
predictions = model(tokens_tensor, segments_tensors)
>>> predicted_index = torch.argmax(predictions[0, masked_index]).item()
>>> predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
'henson'
"""
model = BertForMaskedLM.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForSequenceClassification(*args, **kwargs):
"""
BertForSequenceClassification is a fine-tuning model that includes
BertModel and a sequence-level (sequence or pair of sequences) classifier
on top of the BertModel.
The sequence-level classifier is a linear layer that takes as input the
last hidden state of the first character in the input sequence
(see Figures 3a and 3b in the BERT paper).
Args:
num_labels: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForSequenceClassification', 'bert-base-cased', num_labels=2, force_reload=True)
"""
model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForMultipleChoice(*args, **kwargs):
"""
BertForMultipleChoice is a fine-tuning model that includes BertModel and a
linear layer on top of the BertModel.
Args:
num_choices: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForMultipleChoice', 'bert-base-cased', num_choices=2, force_reload=True)
"""
model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForQuestionAnswering(*args, **kwargs):
"""
BertForQuestionAnswering is a fine-tuning model that includes BertModel
with a token-level classifiers on top of the full sequence of last hidden
states.
"""
model = BertForQuestionAnswering.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(bert_docstring)
def bertForTokenClassification(*args, **kwargs):
"""
BertForTokenClassification is a fine-tuning model that includes BertModel
and a token-level classifier on top of the BertModel.
The token-level classifier is a linear layer that takes as input the last
hidden state of the sequence.
Args:
num_labels: the number (>=2) of classes for the classifier.
Example:
>>> torch.hub.load('huggingface/pytorch-pretrained-BERT', 'bertForTokenClassification', 'bert-base-cased', num_labels=2, force_reload=True)
"""
model = BertForTokenClassification.from_pretrained(*args, **kwargs)
return model

183
hubconfs/gpt_hubconf.py Normal file
View File

@ -0,0 +1,183 @@
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer
from pytorch_pretrained_bert.modeling_openai import (
OpenAIGPTModel,
OpenAIGPTLMHeadModel,
OpenAIGPTDoubleHeadsModel
)
# Dependecies that are not specified in global hubconf.py
specific_dependencies = ['spacy', 'ftfy']
# A lot of models share the same param doc. Use a decorator
# to save typing
gpt_docstring = """
OpenAI GPT use a single embedding matrix to store the word and special embeddings.
Special tokens embeddings are additional tokens that are not pre-trained: [SEP], [CLS]...
Special tokens need to be trained during the fine-tuning if you use them.
The number of special embeddings can be controled using the `set_num_special_tokens(num_special_tokens)` function.
The embeddings are ordered as follow in the token embeddings matrice:
[0, ----------------------
... -> word embeddings
config.vocab_size - 1, ______________________
config.vocab_size,
... -> special embeddings
config.vocab_size + config.n_special - 1] ______________________
where total_tokens_embeddings can be obtained as config.total_tokens_embeddings and is:
total_tokens_embeddings = config.vocab_size + config.n_special
You should use the associate indices to index the embeddings.
Params:
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `openai-gpt`
- a path or url to a pretrained model archive containing:
. `openai_gpt_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
- a path or url to a pretrained model archive containing:
. `openai-gpt-config.json` a configuration file for the model
. a series of NumPy files containing OpenAI TensorFlow trained weights
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object)
to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific OpenAI-GPT class
"""
def _append_from_pretrained_docstring(docstr):
def docstring_decorator(fn):
fn.__doc__ = fn.__doc__ + docstr
return fn
return docstring_decorator
def openAIGPTTokenizer(*args, **kwargs):
"""
Instantiate a BPE tokenizer for OpenAI GPT from a pre-trained/customized vocab file.
Peculiarities:
- lower case all inputs
- uses SpaCy tokenizer ('en' model) and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not.
- argument special_tokens and function set_special_tokens:
can be used to add additional symbols (ex: "__classify__") to a vocabulary.
Args:
pretrained_model_name_or_path: Path to pretrained model archive
or one of pre-trained vocab configs below.
* openai-gpt
Keyword args:
special_tokens: Special tokens in vocabulary that are not pretrained ([SEP], [CLS]...)
Default: None
max_len: An artificial maximum length to truncate tokenized sequences to;
Effective maximum length is always the minimum of this
value (if specified) and the underlying BERT model's
sequence length.
Default: None
Example:
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'openAIGPTTokenizer', 'openai-gpt')
>>> text = "Who was Jim Henson ? Jim Henson was a puppeteer"
>>> tokenized_text = tokenizer.tokenize(text)
>>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
[763, 509, 4265, 2298, 945, 257, 4265, 2298, 945, 509, 246, 10148, 39041, 483]
"""
tokenizer = OpenAIGPTTokenizer.from_pretrained(*args, **kwargs)
return tokenizer
@_append_from_pretrained_docstring(gpt_docstring)
def openAIGPTModel(*args, **kwargs):
"""
OpenAIGPTModel is the basic OpenAI GPT Transformer model based on
identical stacked masked self-attention blocks and pre-trained
on large scale dataset using language modeling signal.
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'openAIGPTTokenizer', 'openai-gpt')
# Prepare tokenized input
>>> text = "Who was Jim Henson ? Jim Henson was a puppeteer"
>>> tokenized_text = tokenizer.tokenize(text)
>>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
>>> tokens_tensor = torch.tensor([indexed_tokens])
# Load openAIGPTModel
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'openAIGPTModel', 'openai-gpt')
>>> model.eval()
# Predict hidden states features for each layer
>>> with torch.no_grad():
hidden_states = model(tokens_tensor)
"""
model = OpenAIGPTModel.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(gpt_docstring)
def openAIGPTLMHeadModel(*args, **kwargs):
"""
OpenAIGPTLMHeadModel is the OpenAI GPT Transformer model with the
tied (pre-trained) language modeling head on top.
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'openAIGPTTokenizer', 'openai-gpt')
# Prepare tokenized input
>>> text = "Who was Jim Henson ? Jim Henson was a puppeteer"
>>> tokenized_text = tokenizer.tokenize(text)
>>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
>>> tokens_tensor = torch.tensor([indexed_tokens])
# Load openAIGPTLMHeadModel
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'openAIGPTLMHeadModel', 'openai-gpt')
>>> model.eval()
# Predict hidden states features for each layer
>>> with torch.no_grad():
predictions = model(tokens_tensor)
# Get the predicted last token
>>> predicted_index = torch.argmax(predictions[0, -1, :]).item()
>>> predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
'.</w>'
"""
model = OpenAIGPTLMHeadModel.from_pretrained(*args, **kwargs)
return model
@_append_from_pretrained_docstring(gpt_docstring)
def openAIGPTDoubleHeadsModel(*args, **kwargs):
"""
OpenAIGPTDoubleHeadsModel is the OpenAI GPT Transformer model with the
tied (pre-trained) language modeling head and a multiple choice
classification head (only initialized, not pre-trained).
Example:
# Load the tokenizer
>>> import torch
>>> tokenizer = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'openAIGPTTokenizer', 'openai-gpt')
# Prepare tokenized input
>>> text = "Who was Jim Henson ? Jim Henson was a puppeteer"
>>> tokenized_text = tokenizer.tokenize(text)
>>> indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
>>> tokens_tensor = torch.tensor([indexed_tokens])
>>> mc_token_ids = torch.LongTensor([ [len(tokenized_text)] ])
# Load openAIGPTDoubleHeadsModel
>>> model = torch.hub.load('huggingface/pytorch-pretrained-BERT', 'openAIGPTDoubleHeadsModel', 'openai-gpt')
>>> model.eval()
# Predict hidden states features for each layer
>>> with torch.no_grad():
lm_logits, multiple_choice_logits = model(tokens_tensor, mc_token_ids)
"""
model = OpenAIGPTDoubleHeadsModel.from_pretrained(*args, **kwargs)
return model

View File

@ -419,9 +419,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
pass
@classmethod
def from_pretrained(
cls, pretrained_model_name_or_path, num_special_tokens=None, state_dict=None, cache_dir=None, from_tf=False, *inputs, **kwargs
):
def from_pretrained(cls, pretrained_model_name_or_path, num_special_tokens=None, *inputs, **kwargs):
"""
Instantiate a OpenAIGPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
Download and cache the pre-trained model file if needed.
@ -434,14 +432,20 @@ class OpenAIGPTPreTrainedModel(nn.Module):
. `openai_gpt_config.json` a configuration file for the model
. `pytorch_model.bin` a PyTorch dump of a OpenAIGPTModel instance
- a path or url to a pretrained model archive containing:
. `bert_config.json` a configuration file for the model
. `openai-gpt-config.json` a configuration file for the model
. a series of NumPy files containing OpenAI TensorFlow trained weights
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use instead of pre-trained models
*inputs, **kwargs: additional input for the specific Bert class
(ex: num_labels for BertForSequenceClassification)
*inputs, **kwargs: additional input for the specific OpenAI-GPT class
"""
state_dict = kwargs.get('state_dict', None)
kwargs.pop('state_dict', None)
cache_dir = kwargs.get('cache_dir', None)
kwargs.pop('cache_dir', None)
from_tf = kwargs.get('from_tf', False)
kwargs.pop('from_tf', None)
if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
archive_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]