Kill model archive maps (#4636)
* Kill model archive maps
* Fixup
* Also kill model_archive_map for MaskedBertPreTrainedModel
* Unhook config_archive_map
* Tokenizers: align with model id changes
* make style && make quality
* Fix CI
parent 47a551d17b
commit d4c2cb402d
@@ -63,33 +63,33 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
 | | | | Trained on uncased German text by DBMDZ |
 | | | (see `details on dbmdz repository <https://github.com/dbmdz/german-bert>`__). |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-japanese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``cl-tohoku/bert-base-japanese`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
 | | | | Trained on Japanese text. Text is tokenized with MeCab and WordPiece. |
 | | | | `MeCab <https://taku910.github.io/mecab/>`__ is required for tokenization. |
 | | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-japanese-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``cl-tohoku/bert-base-japanese-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
 | | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized with MeCab and WordPiece. |
 | | | | `MeCab <https://taku910.github.io/mecab/>`__ is required for tokenization. |
 | | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-japanese-char`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``cl-tohoku/bert-base-japanese-char`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
 | | | | Trained on Japanese text. Text is tokenized into characters. |
 | | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-japanese-char-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``cl-tohoku/bert-base-japanese-char-whole-word-masking`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
 | | | | Trained on Japanese text using Whole-Word-Masking. Text is tokenized into characters. |
 | | | (see `details on cl-tohoku repository <https://github.com/cl-tohoku/bert-japanese>`__). |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-finnish-cased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``TurkuNLP/bert-base-finnish-cased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
 | | | | Trained on cased Finnish text. |
 | | | (see `details on turkunlp.org <http://turkunlp.org/FinBERT/>`__). |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-finnish-uncased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``TurkuNLP/bert-base-finnish-uncased-v1`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
 | | | | Trained on uncased Finnish text. |
 | | | (see `details on turkunlp.org <http://turkunlp.org/FinBERT/>`__). |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bert-base-dutch-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
+| | ``wietsedv/bert-base-dutch-cased`` | | 12-layer, 768-hidden, 12-heads, 110M parameters. |
 | | | | Trained on cased Dutch text. |
 | | | (see `details on wietsedv repository <https://github.com/wietsedv/bertje/>`__). |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+

@@ -259,32 +259,32 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
 | | ``xlm-roberta-large`` | | ~355M parameters with 24-layers, 1027-hidden-state, 4096 feed-forward hidden-state, 16-heads, |
 | | | | Trained on 2.5 TB of newly created clean CommonCrawl data in 100 languages |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| FlauBERT | ``flaubert-small-cased`` | | 6-layer, 512-hidden, 8-heads, 54M parameters |
+| FlauBERT | ``flaubert/flaubert_small_cased`` | | 6-layer, 512-hidden, 8-heads, 54M parameters |
 | | | | FlauBERT small architecture |
 | | | (see `details <https://github.com/getalp/Flaubert>`__) |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``flaubert-base-uncased`` | | 12-layer, 768-hidden, 12-heads, 137M parameters |
+| | ``flaubert/flaubert_base_uncased`` | | 12-layer, 768-hidden, 12-heads, 137M parameters |
 | | | | FlauBERT base architecture with uncased vocabulary |
 | | | (see `details <https://github.com/getalp/Flaubert>`__) |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``flaubert-base-cased`` | | 12-layer, 768-hidden, 12-heads, 138M parameters |
+| | ``flaubert/flaubert_base_cased`` | | 12-layer, 768-hidden, 12-heads, 138M parameters |
 | | | | FlauBERT base architecture with cased vocabulary |
 | | | (see `details <https://github.com/getalp/Flaubert>`__) |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``flaubert-large-cased`` | | 24-layer, 1024-hidden, 16-heads, 373M parameters |
+| | ``flaubert/flaubert_large_cased`` | | 24-layer, 1024-hidden, 16-heads, 373M parameters |
 | | | | FlauBERT large architecture |
 | | | (see `details <https://github.com/getalp/Flaubert>`__) |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| Bart | ``bart-large`` | | 24-layer, 1024-hidden, 16-heads, 406M parameters |
+| Bart | ``facebook/bart-large`` | | 24-layer, 1024-hidden, 16-heads, 406M parameters |
 | | | (see `details <https://github.com/pytorch/fairseq/tree/master/examples/bart>`_) |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
+| | ``facebook/bart-large-mnli`` | | Adds a 2 layer classification head with 1 million parameters |
 | | | | bart-large base architecture with a classification head, finetuned on MNLI |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) |
+| | ``facebook/bart-large-cnn`` | | 12-layer, 1024-hidden, 16-heads, 406M parameters (same as base) |
 | | | | bart-large base architecture finetuned on cnn summarization task |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``mbart-large-en-ro`` | | 12-layer, 1024-hidden, 16-heads, 880M parameters |
+| | ``facebook/mbart-large-en-ro`` | | 12-layer, 1024-hidden, 16-heads, 880M parameters |
 | | | | bart-large architecture pretrained on cc25 multilingual data , finetuned on WMT english romanian translation. |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
 | DialoGPT | ``DialoGPT-small`` | | 12-layer, 768-hidden, 12-heads, 124M parameters |

@@ -305,9 +305,9 @@ For a list that includes community-uploaded models, refer to `https://huggingfac
 | MarianMT | ``Helsinki-NLP/opus-mt-{src}-{tgt}`` | | 12-layer, 512-hidden, 8-heads, ~74M parameter Machine translation models. Parameter counts vary depending on vocab size. |
 | | | | (see `model list <https://huggingface.co/Helsinki-NLP>`_) |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| Longformer | ``longformer-base-4096`` | | 12-layer, 768-hidden, 12-heads, ~149M parameters |
+| Longformer | ``allenai/longformer-base-4096`` | | 12-layer, 768-hidden, 12-heads, ~149M parameters |
 | | | | Starting from RoBERTa-base checkpoint, trained on documents of max length 4,096 |
 | +------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
-| | ``longformer-large-4096`` | | 24-layer, 1024-hidden, 16-heads, ~435M parameters |
+| | ``allenai/longformer-large-4096`` | | 24-layer, 1024-hidden, 16-heads, ~435M parameters |
 | | | | Starting from RoBERTa-large checkpoint, trained on documents of max length 4,096 |
 +-------------------+------------------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------+
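Note: the table rows above track this commit's renames from bare shortcut names to namespaced model ids. As a minimal sketch (assuming a recent transformers install and access to huggingface.co), a renamed checkpoint loads exactly as before; only the identifier string changes:

    from transformers import AutoModel, AutoTokenizer

    # The namespaced id replaces the old bare shortcut "bert-base-finnish-cased-v1".
    model_id = "TurkuNLP/bert-base-finnish-cased-v1"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModel.from_pretrained(model_id)

    inputs = tokenizer("Hyvää huomenta!", return_tensors="pt")
    outputs = model(**inputs)
    print(outputs[0].shape)  # last hidden state: (batch, seq_len, 768)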
@@ -65,13 +65,6 @@ except ImportError:

 logger = logging.getLogger(__name__)

-ALL_MODELS = sum(
-    (
-        tuple(conf.pretrained_config_archive_map.keys())
-        for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
-    ),
-    (),
-)
-
 MODEL_CLASSES = {
     "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),

@@ -389,7 +382,7 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
     )
     parser.add_argument(
         "--task_name",
@@ -34,26 +34,11 @@ from tqdm import tqdm, trange
 from transformers import (
     WEIGHTS_NAME,
     AdamW,
-    AlbertConfig,
-    AlbertModel,
-    AlbertTokenizer,
-    BertConfig,
-    BertModel,
-    BertTokenizer,
-    DistilBertConfig,
-    DistilBertModel,
-    DistilBertTokenizer,
+    AutoConfig,
+    AutoModel,
+    AutoTokenizer,
     MMBTConfig,
     MMBTForClassification,
-    RobertaConfig,
-    RobertaModel,
-    RobertaTokenizer,
-    XLMConfig,
-    XLMModel,
-    XLMTokenizer,
-    XLNetConfig,
-    XLNetModel,
-    XLNetTokenizer,
     get_linear_schedule_with_warmup,
 )
 from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_image_transforms, get_mmimdb_labels

@@ -67,23 +52,6 @@ except ImportError:

 logger = logging.getLogger(__name__)

-ALL_MODELS = sum(
-    (
-        tuple(conf.pretrained_config_archive_map.keys())
-        for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
-    ),
-    (),
-)
-
-MODEL_CLASSES = {
-    "bert": (BertConfig, BertModel, BertTokenizer),
-    "xlnet": (XLNetConfig, XLNetModel, XLNetTokenizer),
-    "xlm": (XLMConfig, XLMModel, XLMTokenizer),
-    "roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
-    "albert": (AlbertConfig, AlbertModel, AlbertTokenizer),
-}
-

 def set_seed(args):
     random.seed(args.seed)

@@ -351,19 +319,12 @@ def main():
         required=True,
         help="The input data dir. Should contain the .jsonl files for MMIMDB.",
     )
-    parser.add_argument(
-        "--model_type",
-        default=None,
-        type=str,
-        required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
-    )
     parser.add_argument(
         "--model_name_or_path",
         default=None,
         type=str,
         required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
     )
     parser.add_argument(
         "--output_dir",

@@ -385,7 +346,7 @@ def main():
     )
     parser.add_argument(
         "--cache_dir",
-        default="",
+        default=None,
         type=str,
         help="Where do you want to store the pre-trained models downloaded from s3",
     )

@@ -526,18 +487,14 @@ def main():
     # Setup model
     labels = get_mmimdb_labels()
     num_labels = len(labels)
-    args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    transformer_config = config_class.from_pretrained(
-        args.config_name if args.config_name else args.model_name_or_path
-    )
-    tokenizer = tokenizer_class.from_pretrained(
+    transformer_config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
         do_lower_case=args.do_lower_case,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
     )
-    transformer = model_class.from_pretrained(
-        args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir if args.cache_dir else None
+    transformer = AutoModel.from_pretrained(
+        args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir
     )
     img_encoder = ImageEncoder(args)
     config = MMBTConfig(transformer_config, num_labels=num_labels)

@@ -583,13 +540,12 @@ def main():
         # Load a trained model and vocabulary that you have fine-tuned
         model = MMBTForClassification(config, transformer, img_encoder)
         model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME)))
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
         model.to(args.device)

     # Evaluation
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
             checkpoints = list(
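Note: the mmimdb hunks above are the core pattern of this commit. The per-architecture (config, model, tokenizer) lookup table keyed by --model_type is replaced by the Auto classes, which infer the architecture from the checkpoint's own config. A minimal before/after sketch (the commented-out MODEL_CLASSES lookup is illustrative, not the exact table from the script):

    from transformers import AutoConfig, AutoModel, AutoTokenizer

    # Before: the caller had to name the architecture explicitly.
    #   config_class, model_class, tokenizer_class = MODEL_CLASSES["bert"]
    #   model = model_class.from_pretrained("bert-base-uncased")

    # After: the architecture is read from the checkpoint's config.json,
    # so one code path serves BERT, RoBERTa, ALBERT, ...
    name = "bert-base-uncased"
    config = AutoConfig.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name, config=config)
    print(config.model_type)  # "bert", inferred rather than passed on the CLI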
@@ -31,14 +31,8 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Tenso
 from torch.utils.data.distributed import DistributedSampler
 from tqdm import tqdm, trange

-from transformers import (
-    WEIGHTS_NAME,
-    AdamW,
-    BertConfig,
-    BertForMultipleChoice,
-    BertTokenizer,
-    get_linear_schedule_with_warmup,
-)
+from transformers import WEIGHTS_NAME, AdamW, AutoConfig, AutoTokenizer, get_linear_schedule_with_warmup
+from transformers.modeling_auto import AutoModelForMultipleChoice


 try:

@@ -49,12 +43,6 @@ except ImportError:

 logger = logging.getLogger(__name__)

-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in [BertConfig]), ())
-
-MODEL_CLASSES = {
-    "bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
-}
-

 class SwagExample(object):
     """A single training/test example for the SWAG dataset."""

@@ -492,19 +480,12 @@ def main():
         required=True,
         help="SWAG csv for predictions. E.g., val.csv or test.csv",
     )
-    parser.add_argument(
-        "--model_type",
-        default=None,
-        type=str,
-        required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
-    )
     parser.add_argument(
         "--model_name_or_path",
         default=None,
         type=str,
         required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
     )
     parser.add_argument(
         "--output_dir",

@@ -536,9 +517,6 @@ def main():
     parser.add_argument(
         "--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
     )
-    parser.add_argument(
-        "--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
-    )

     parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
     parser.add_argument(

@@ -652,13 +630,9 @@ def main():
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

-    args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
-    tokenizer = tokenizer_class.from_pretrained(
-        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case
-    )
-    model = model_class.from_pretrained(
+    config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
+    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,)
+    model = AutoModelForMultipleChoice.from_pretrained(
         args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config
     )

@@ -694,8 +668,8 @@ def main():
         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
+        model = AutoModelForMultipleChoice.from_pretrained(args.output_dir)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
         model.to(args.device)

     # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory

@@ -718,8 +692,8 @@ def main():
         for checkpoint in checkpoints:
             # Reload the model
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
-            model = model_class.from_pretrained(checkpoint)
-            tokenizer = tokenizer_class.from_pretrained(checkpoint)
+            model = AutoModelForMultipleChoice.from_pretrained(checkpoint)
+            tokenizer = AutoTokenizer.from_pretrained(checkpoint)
             model.to(args.device)

             # Evaluate
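Note: one consequence visible in the SWAG hunks is that the save/reload round trip no longer needs the MODEL_CLASSES table either, because the checkpoint directory carries its own config. A hedged sketch of that round trip (the path is a placeholder; in current transformers AutoModelForMultipleChoice is importable from the top-level package rather than from modeling_auto):

    from transformers import AutoModelForMultipleChoice, AutoTokenizer

    output_dir = "/tmp/swag-finetuned"  # placeholder path

    # Saving writes config.json next to the weights...
    model = AutoModelForMultipleChoice.from_pretrained("bert-base-uncased")
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    # ...so reloading needs no --model_type: the directory is self-describing.
    model = AutoModelForMultipleChoice.from_pretrained(output_dir)
    tokenizer = AutoTokenizer.from_pretrained(output_dir)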
@@ -67,9 +67,6 @@ except ImportError:

 logger = logging.getLogger(__name__)

-ALL_MODELS = sum(
-    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
-)

 MODEL_CLASSES = {
     "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),

@@ -505,7 +502,7 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
     )
     parser.add_argument(
         "--output_dir",
@@ -19,7 +19,6 @@ and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init`

 import logging

-from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
 from transformers.configuration_utils import PretrainedConfig

@@ -31,7 +30,6 @@ class MaskedBertConfig(PretrainedConfig):
     A class replicating the `~transformers.BertConfig` with additional parameters for pruning/masking configuration.
     """

-    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
     model_type = "masked_bert"

     def __init__(
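Note: after this hunk, a config subclass only needs model_type; there is no class-level registry of download URLs to maintain. A hedged sketch of what a minimal custom config looks like under that scheme (the class and field names are illustrative):

    from transformers import PretrainedConfig


    class TinyDemoConfig(PretrainedConfig):
        """Illustrative config: no pretrained_config_archive_map needed."""

        model_type = "tiny_demo"  # the only required class attribute

        def __init__(self, hidden_size=64, num_layers=2, **kwargs):
            super().__init__(**kwargs)
            self.hidden_size = hidden_size
            self.num_layers = num_layers


    config = TinyDemoConfig(hidden_size=128)
    config.save_pretrained("/tmp/tiny-demo")  # writes config.json; reload with from_pretrained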
@@ -29,12 +29,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
 from emmental import MaskedBertConfig
 from emmental.modules import MaskedLinear
 from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
-from transformers.modeling_bert import (
-    ACT2FN,
-    BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-    BertLayerNorm,
-    load_tf_weights_in_bert,
-)
+from transformers.modeling_bert import ACT2FN, BertLayerNorm, load_tf_weights_in_bert
 from transformers.modeling_utils import PreTrainedModel, prune_linear_layer

@@ -395,7 +390,6 @@ class MaskedBertPreTrainedModel(PreTrainedModel):
     """

     config_class = MaskedBertConfig
-    pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
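Note: the same trimming applies on the model side. A PreTrainedModel subclass keeps config_class and base_model_prefix, and pretrained weights are resolved from the model id or local path alone. A sketch of the reduced attribute set, mirroring MaskedBertPreTrainedModel above (names illustrative; the exact required hooks such as _init_weights depend on the transformers version):

    import torch.nn as nn
    from transformers import PretrainedConfig, PreTrainedModel


    class TinyDemoConfig(PretrainedConfig):
        model_type = "tiny_demo"

        def __init__(self, hidden_size=64, **kwargs):
            super().__init__(**kwargs)
            self.hidden_size = hidden_size


    class TinyDemoPreTrainedModel(PreTrainedModel):
        # After this commit: no pretrained_model_archive_map class attribute.
        config_class = TinyDemoConfig
        base_model_prefix = "demo"

        def _init_weights(self, module):
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=0.02)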
@@ -53,8 +53,6 @@ except ImportError:

 logger = logging.getLogger(__name__)

-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), (),)
-
 MODEL_CLASSES = {
     "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
     "masked_bert": (MaskedBertConfig, MaskedBertForSequenceClassification, BertTokenizer),

@@ -576,7 +574,7 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
     )
     parser.add_argument(
         "--task_name",
@@ -57,8 +57,6 @@ except ImportError:

 logger = logging.getLogger(__name__)

-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), (),)
-
 MODEL_CLASSES = {
     "bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
     "masked_bert": (MaskedBertConfig, MaskedBertForQuestionAnswering, BertTokenizer),

@@ -673,7 +671,7 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
     )
     parser.add_argument(
         "--output_dir",
@@ -58,8 +58,6 @@ logger = logging.getLogger(__name__)
 MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)


 def set_seed(args):
     random.seed(args.seed)

@@ -491,7 +489,7 @@ def main():
         default=None,
         type=str,
         required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
     )
     parser.add_argument(
         "--output_dir",
@@ -61,7 +61,6 @@ class BertAbsConfig(PretrainedConfig):
         the decoder.
     """

-    pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP
     model_type = "bertabs"

     def __init__(
@@ -33,14 +33,13 @@ from transformers import BertConfig, BertModel, PreTrainedModel

 MAX_SIZE = 5000

-BERTABS_FINETUNED_MODEL_MAP = {
-    "bertabs-finetuned-cnndm": "https://cdn.huggingface.co/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization/pytorch_model.bin",
-}
+BERTABS_FINETUNED_MODEL_ARCHIVE_LIST = [
+    "remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization",
+]


 class BertAbsPreTrainedModel(PreTrainedModel):
     config_class = BertAbsConfig
-    pretrained_model_archive_map = BERTABS_FINETUNED_MODEL_MAP
     load_tf_weights = False
     base_model_prefix = "bert"
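Note: this hunk shows the shape of the whole commit in miniature. The dict mapping shortcut names to raw S3/CDN weight URLs becomes a plain list of model ids, kept only as documentation of which checkpoints exist. A sketch of how such a list is consumed (the list literal is copied from the hunk; the loop is illustrative):

    BERTABS_FINETUNED_MODEL_ARCHIVE_LIST = [
        "remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization",
    ]

    # The list no longer drives weight resolution; from_pretrained() takes the
    # model id directly. It remains useful for docs, tests, and sanity checks:
    for model_id in BERTABS_FINETUNED_MODEL_ARCHIVE_LIST:
        assert "/" in model_id, "post-rename ids are namespaced as org/name"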
@@ -258,7 +258,7 @@ TEST RESULTS {'val_loss': tensor(0.0707), 'precision': 0.852427800698191, 'recal

 Based on the script [`run_xnli.py`](https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_xnli.py).

-[XNLI](https://www.nyu.edu/projects/bowman/xnli/) is crowd-sourced dataset based on [MultiNLI](http://www.nyu.edu/projects/bowman/multinli/). It is an evaluation benchmark for cross-lingual text representations. Pairs of text are labeled with textual entailment annotations for 15 different languages (including both high-resource language such as English and low-resource languages such as Swahili).
+[XNLI](https://www.nyu.edu/projects/bowman/xnli/) is a crowd-sourced dataset based on [MultiNLI](http://www.nyu.edu/projects/bowman/multinli/). It is an evaluation benchmark for cross-lingual text representations. Pairs of text are labeled with textual entailment annotations for 15 different languages (including both high-resource language such as English and low-resource languages such as Swahili).

 #### Fine-tuning on XNLI

@@ -273,7 +273,6 @@ on a single tesla V100 16GB. The data for XNLI can be downloaded with the follow
 export XNLI_DIR=/path/to/XNLI

 python run_xnli.py \
-  --model_type bert \
   --model_name_or_path bert-base-multilingual-cased \
   --language de \
   --train_language en \
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Finetuning multi-lingual models on XNLI (Bert, DistilBERT, XLM).
+""" Finetuning multi-lingual models on XNLI (e.g. Bert, DistilBERT, XLM).
 Adapted from `examples/text-classification/run_glue.py`"""


@@ -32,15 +32,9 @@ from tqdm import tqdm, trange
 from transformers import (
     WEIGHTS_NAME,
     AdamW,
-    BertConfig,
-    BertForSequenceClassification,
-    BertTokenizer,
-    DistilBertConfig,
-    DistilBertForSequenceClassification,
-    DistilBertTokenizer,
-    XLMConfig,
-    XLMForSequenceClassification,
-    XLMTokenizer,
+    AutoConfig,
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
     get_linear_schedule_with_warmup,
 )
 from transformers import glue_convert_examples_to_features as convert_examples_to_features

@@ -57,16 +51,6 @@ except ImportError:

 logger = logging.getLogger(__name__)

-ALL_MODELS = sum(
-    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, DistilBertConfig, XLMConfig)), ()
-)
-
-MODEL_CLASSES = {
-    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
-    "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
-    "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
-}
-

 def set_seed(args):
     random.seed(args.seed)

@@ -377,19 +361,12 @@ def main():
         required=True,
         help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
     )
-    parser.add_argument(
-        "--model_type",
-        default=None,
-        type=str,
-        required=True,
-        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
-    )
     parser.add_argument(
         "--model_name_or_path",
         default=None,
         type=str,
         required=True,
-        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
+        help="Path to pretrained model or model identifier from huggingface.co/models",
     )
     parser.add_argument(
         "--language",

@@ -421,7 +398,7 @@ def main():
     )
     parser.add_argument(
         "--cache_dir",
-        default="",
+        default=None,
         type=str,
         help="Where do you want to store the pre-trained models downloaded from s3",
     )

@@ -562,24 +539,23 @@ def main():
     if args.local_rank not in [-1, 0]:
         torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

-    args.model_type = args.model_type.lower()
-    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
-    config = config_class.from_pretrained(
+    config = AutoConfig.from_pretrained(
         args.config_name if args.config_name else args.model_name_or_path,
         num_labels=num_labels,
         finetuning_task=args.task_name,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
     )
-    tokenizer = tokenizer_class.from_pretrained(
+    args.model_type = config.model_type
+    tokenizer = AutoTokenizer.from_pretrained(
         args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
         do_lower_case=args.do_lower_case,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
     )
-    model = model_class.from_pretrained(
+    model = AutoModelForSequenceClassification.from_pretrained(
         args.model_name_or_path,
         from_tf=bool(".ckpt" in args.model_name_or_path),
         config=config,
-        cache_dir=args.cache_dir if args.cache_dir else None,
+        cache_dir=args.cache_dir,
     )

     if args.local_rank == 0:

@@ -614,14 +590,13 @@ def main():
         torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

         # Load a trained model and vocabulary that you have fine-tuned
-        model = model_class.from_pretrained(args.output_dir)
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
+        model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
+        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
         model.to(args.device)

     # Evaluation
     results = {}
     if args.do_eval and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
         checkpoints = [args.output_dir]
         if args.eval_all_checkpoints:
             checkpoints = list(

@@ -633,7 +608,7 @@ def main():
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
             prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

-            model = model_class.from_pretrained(checkpoint)
+            model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
             model.to(args.device)
             result = evaluate(args, model, tokenizer, prefix=prefix)
             result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
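Note: worth pausing on one line in the XNLI hunk, args.model_type = config.model_type. The script still needs a model-type string downstream (XLM variants, for instance, take language information), so instead of asking the user for it, it is now derived from the loaded config. A sketch of that inversion (model ids are examples of existing checkpoints):

    from transformers import AutoConfig

    # Before: the user passed --model_type and could get it wrong.
    # After: the config is authoritative for the architecture family.
    for name in ("bert-base-multilingual-cased", "xlm-mlm-tlm-xnli15-1024"):
        config = AutoConfig.from_pretrained(name)
        model_type = config.model_type  # "bert" or "xlm"
        uses_lang_embeddings = model_type == "xlm"
        print(name, "->", model_type, "| lang embeddings:", uses_lang_embeddings)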
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Fine-tuning the library models for named entity recognition on CoNLL-2003 (Bert or Roberta). """
+""" Fine-tuning the library models for named entity recognition on CoNLL-2003. """


 import logging
@@ -159,7 +159,6 @@ if is_torch_available():
         AutoModelWithLMHead,
         AutoModelForTokenClassification,
         AutoModelForMultipleChoice,
-        ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
         MODEL_MAPPING,
         MODEL_FOR_PRETRAINING_MAPPING,
         MODEL_WITH_LM_HEAD_MAPPING,

@@ -180,7 +179,7 @@ if is_torch_available():
         BertForTokenClassification,
         BertForQuestionAnswering,
         load_tf_weights_in_bert,
-        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         BertLayer,
     )
     from .modeling_openai import (

@@ -189,7 +188,7 @@ if is_torch_available():
         OpenAIGPTLMHeadModel,
         OpenAIGPTDoubleHeadsModel,
         load_tf_weights_in_openai_gpt,
-        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_transfo_xl import (
         TransfoXLPreTrainedModel,

@@ -197,7 +196,7 @@ if is_torch_available():
         TransfoXLLMHeadModel,
         AdaptiveEmbedding,
         load_tf_weights_in_transfo_xl,
-        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_gpt2 import (
         GPT2PreTrainedModel,

@@ -205,9 +204,9 @@ if is_torch_available():
         GPT2LMHeadModel,
         GPT2DoubleHeadsModel,
         load_tf_weights_in_gpt2,
-        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
+        GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
-    from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
+    from .modeling_ctrl import CTRLPreTrainedModel, CTRLModel, CTRLLMHeadModel, CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
     from .modeling_xlnet import (
         XLNetPreTrainedModel,
         XLNetModel,

@@ -218,7 +217,7 @@ if is_torch_available():
         XLNetForQuestionAnsweringSimple,
         XLNetForQuestionAnswering,
         load_tf_weights_in_xlnet,
-        XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
+        XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_xlm import (
         XLMPreTrainedModel,

@@ -228,13 +227,13 @@ if is_torch_available():
         XLMForTokenClassification,
         XLMForQuestionAnswering,
         XLMForQuestionAnsweringSimple,
-        XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
+        XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_bart import (
         BartForSequenceClassification,
         BartModel,
         BartForConditionalGeneration,
-        BART_PRETRAINED_MODEL_ARCHIVE_MAP,
+        BART_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_marian import MarianMTModel
     from .tokenization_marian import MarianTokenizer

@@ -245,7 +244,7 @@ if is_torch_available():
         RobertaForMultipleChoice,
         RobertaForTokenClassification,
         RobertaForQuestionAnswering,
-        ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_distilbert import (
         DistilBertPreTrainedModel,

@@ -254,7 +253,7 @@ if is_torch_available():
         DistilBertForSequenceClassification,
         DistilBertForQuestionAnswering,
         DistilBertForTokenClassification,
-        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_camembert import (
         CamembertForMaskedLM,

@@ -263,7 +262,7 @@ if is_torch_available():
         CamembertForMultipleChoice,
         CamembertForTokenClassification,
         CamembertForQuestionAnswering,
-        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_encoder_decoder import EncoderDecoderModel
     from .modeling_t5 import (

@@ -271,7 +270,7 @@ if is_torch_available():
         T5Model,
         T5ForConditionalGeneration,
         load_tf_weights_in_t5,
-        T5_PRETRAINED_MODEL_ARCHIVE_MAP,
+        T5_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_albert import (
         AlbertPreTrainedModel,

@@ -282,7 +281,7 @@ if is_torch_available():
         AlbertForQuestionAnswering,
         AlbertForTokenClassification,
         load_tf_weights_in_albert,
-        ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_xlm_roberta import (
         XLMRobertaForMaskedLM,

@@ -290,7 +289,7 @@ if is_torch_available():
         XLMRobertaForMultipleChoice,
         XLMRobertaForSequenceClassification,
         XLMRobertaForTokenClassification,
-        XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
     )
     from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification

@@ -300,7 +299,7 @@ if is_torch_available():
         FlaubertForSequenceClassification,
         FlaubertForQuestionAnswering,
         FlaubertForQuestionAnsweringSimple,
-        FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_electra import (

@@ -311,7 +310,7 @@ if is_torch_available():
         ElectraForSequenceClassification,
         ElectraModel,
         load_tf_weights_in_electra,
-        ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_reformer import (

@@ -319,7 +318,7 @@ if is_torch_available():
         ReformerLayer,
         ReformerModel,
         ReformerModelWithLMHead,
-        REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
+        REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_longformer import (

@@ -329,7 +328,7 @@ if is_torch_available():
         LongformerForMultipleChoice,
         LongformerForTokenClassification,
         LongformerForQuestionAnswering,
-        LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
+        LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     # Optimization

@@ -367,7 +366,6 @@ if is_tf_available():
         TFAutoModelForQuestionAnswering,
         TFAutoModelWithLMHead,
         TFAutoModelForTokenClassification,
-        TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
         TF_MODEL_MAPPING,
         TF_MODEL_FOR_PRETRAINING_MAPPING,
         TF_MODEL_WITH_LM_HEAD_MAPPING,

@@ -388,7 +386,7 @@ if is_tf_available():
         TFBertForMultipleChoice,
         TFBertForTokenClassification,
         TFBertForQuestionAnswering,
-        TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_gpt2 import (

@@ -397,7 +395,7 @@ if is_tf_available():
         TFGPT2Model,
         TFGPT2LMHeadModel,
         TFGPT2DoubleHeadsModel,
-        TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_openai import (

@@ -406,7 +404,7 @@ if is_tf_available():
         TFOpenAIGPTModel,
         TFOpenAIGPTLMHeadModel,
         TFOpenAIGPTDoubleHeadsModel,
-        TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_transfo_xl import (

@@ -414,7 +412,7 @@ if is_tf_available():
         TFTransfoXLMainLayer,
         TFTransfoXLModel,
         TFTransfoXLLMHeadModel,
-        TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
         TFAdaptiveEmbedding,
     )

@@ -426,7 +424,7 @@ if is_tf_available():
         TFXLNetForSequenceClassification,
         TFXLNetForTokenClassification,
         TFXLNetForQuestionAnsweringSimple,
-        TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_xlm import (

@@ -436,7 +434,7 @@ if is_tf_available():
         TFXLMWithLMHeadModel,
         TFXLMForSequenceClassification,
         TFXLMForQuestionAnsweringSimple,
-        TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_xlm_roberta import (

@@ -444,7 +442,7 @@ if is_tf_available():
         TFXLMRobertaModel,
         TFXLMRobertaForSequenceClassification,
         TFXLMRobertaForTokenClassification,
-        TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_roberta import (

@@ -455,7 +453,7 @@ if is_tf_available():
         TFRobertaForSequenceClassification,
         TFRobertaForTokenClassification,
         TFRobertaForQuestionAnswering,
-        TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_camembert import (

@@ -463,14 +461,14 @@ if is_tf_available():
         TFCamembertForMaskedLM,
         TFCamembertForSequenceClassification,
         TFCamembertForTokenClassification,
-        TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_flaubert import (
         TFFlaubertModel,
         TFFlaubertWithLMHeadModel,
         TFFlaubertForSequenceClassification,
-        TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_distilbert import (

@@ -481,14 +479,14 @@ if is_tf_available():
         TFDistilBertForSequenceClassification,
         TFDistilBertForTokenClassification,
         TFDistilBertForQuestionAnswering,
-        TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_ctrl import (
         TFCTRLPreTrainedModel,
         TFCTRLModel,
         TFCTRLLMHeadModel,
-        TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_albert import (

@@ -500,14 +498,14 @@ if is_tf_available():
         TFAlbertForMultipleChoice,
         TFAlbertForSequenceClassification,
         TFAlbertForQuestionAnswering,
-        TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_t5 import (
         TFT5PreTrainedModel,
         TFT5Model,
         TFT5ForConditionalGeneration,
-        TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     from .modeling_tf_electra import (

@@ -516,7 +514,7 @@ if is_tf_available():
         TFElectraForPreTraining,
         TFElectraForMaskedLM,
         TFElectraForTokenClassification,
-        TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
+        TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST,
     )

     # Optimization
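Note: across the __init__ hunks the public surface changes uniformly: *_PRETRAINED_MODEL_ARCHIVE_MAP dicts give way to *_PRETRAINED_MODEL_ARCHIVE_LIST lists, and the aggregate ALL_* maps disappear with no replacement. A hedged sketch of the before/after for downstream code that enumerated checkpoints (the MAP literal is illustrative of the old shape, and the import assumes a transformers version that still exports the list constants):

    # Before: a dict of shortcut name -> hard-coded weight URL.
    # BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    #     "bert-base-uncased": "https://.../bert-base-uncased-pytorch_model.bin",
    #     ...
    # }
    # checkpoints = list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())

    # After: a plain list of model ids; URLs are resolved by from_pretrained().
    from transformers import BERT_PRETRAINED_MODEL_ARCHIVE_LIST

    checkpoints = list(BERT_PRETRAINED_MODEL_ARCHIVE_LIST)
    print(checkpoints[:3])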
@@ -32,7 +32,7 @@ ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {

 class AlbertConfig(PretrainedConfig):
     r"""
-    This is the configuration class to store the configuration of an :class:`~transformers.AlbertModel`.
+    This is the configuration class to store the configuration of a :class:`~transformers.AlbertModel`.
     It is used to instantiate an ALBERT model according to the specified arguments, defining the model
     architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
     the ALBERT `xxlarge <https://huggingface.co/albert-xxlarge-v2>`__ architecture.

@@ -97,13 +97,8 @@ class AlbertConfig(PretrainedConfig):

         # Accessing the model configuration
         configuration = model.config

-    Attributes:
-        pretrained_config_archive_map (Dict[str, str]):
-            A dictionary containing all the available pre-trained checkpoints.
     """

-    pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
     model_type = "albert"

     def __init__(
@@ -113,12 +113,12 @@ class AutoConfig:
         )

     @classmethod
-    def for_model(cls, model_type, *args, **kwargs):
-        for pattern, config_class in CONFIG_MAPPING.items():
-            if pattern in model_type:
-                return config_class(*args, **kwargs)
+    def for_model(cls, model_type: str, *args, **kwargs):
+        if model_type in CONFIG_MAPPING:
+            config_class = CONFIG_MAPPING[model_type]
+            return config_class(*args, **kwargs)
         raise ValueError(
-            "Unrecognized model identifier in {}. Should contain one of {}".format(
+            "Unrecognized model identifier: {}. Should contain one of {}".format(
                 model_type, ", ".join(CONFIG_MAPPING.keys())
             )
         )
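Note: the for_model change above swaps fragile substring matching for an exact dictionary lookup. Previously a model_type like "xlm-roberta" could be shadowed by the shorter "xlm" pattern, depending on iteration order. A sketch of the behavioral difference (the mapping is trimmed to two string entries purely for illustration):

    # Illustrative two-entry stand-in for CONFIG_MAPPING.
    CONFIG_MAPPING = {"xlm": "XLMConfig", "xlm-roberta": "XLMRobertaConfig"}

    model_type = "xlm-roberta"

    # Old behavior: first pattern that is a substring wins -> picks "xlm" here.
    old = next(cls for pattern, cls in CONFIG_MAPPING.items() if pattern in model_type)

    # New behavior: exact key lookup, unambiguous.
    new = CONFIG_MAPPING[model_type]

    print(old, "vs", new)  # XLMConfig vs XLMRobertaConfig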
@@ -130,24 +130,24 @@ class AutoConfig:

         The configuration class to instantiate is selected
         based on the `model_type` property of the config object, or when it's missing,
-        falling back to using pattern matching on the `pretrained_model_name_or_path` string.
-            - contains `t5`: :class:`~transformers.T5Config` (T5 model)
-            - contains `distilbert`: :class:`~transformers.DistilBertConfig` (DistilBERT model)
-            - contains `albert`: :class:`~transformers.AlbertConfig` (ALBERT model)
-            - contains `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model)
-            - contains `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model)
-            - contains `longformer`: :class:`~transformers.LongformerConfig` (Longformer model)
-            - contains `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model)
-            - contains `reformer`: :class:`~transformers.ReformerConfig` (Reformer model)
-            - contains `bert`: :class:`~transformers.BertConfig` (Bert model)
-            - contains `openai-gpt`: :class:`~transformers.OpenAIGPTConfig` (OpenAI GPT model)
-            - contains `gpt2`: :class:`~transformers.GPT2Config` (OpenAI GPT-2 model)
-            - contains `transfo-xl`: :class:`~transformers.TransfoXLConfig` (Transformer-XL model)
-            - contains `xlnet`: :class:`~transformers.XLNetConfig` (XLNet model)
-            - contains `xlm`: :class:`~transformers.XLMConfig` (XLM model)
-            - contains `ctrl` : :class:`~transformers.CTRLConfig` (CTRL model)
-            - contains `flaubert` : :class:`~transformers.FlaubertConfig` (Flaubert model)
-            - contains `electra` : :class:`~transformers.ElectraConfig` (ELECTRA model)
+        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
+            - `t5`: :class:`~transformers.T5Config` (T5 model)
+            - `distilbert`: :class:`~transformers.DistilBertConfig` (DistilBERT model)
+            - `albert`: :class:`~transformers.AlbertConfig` (ALBERT model)
+            - `camembert`: :class:`~transformers.CamembertConfig` (CamemBERT model)
+            - `xlm-roberta`: :class:`~transformers.XLMRobertaConfig` (XLM-RoBERTa model)
+            - `longformer`: :class:`~transformers.LongformerConfig` (Longformer model)
+            - `roberta`: :class:`~transformers.RobertaConfig` (RoBERTa model)
+            - `reformer`: :class:`~transformers.ReformerConfig` (Reformer model)
+            - `bert`: :class:`~transformers.BertConfig` (Bert model)
+            - `openai-gpt`: :class:`~transformers.OpenAIGPTConfig` (OpenAI GPT model)
+            - `gpt2`: :class:`~transformers.GPT2Config` (OpenAI GPT-2 model)
+            - `transfo-xl`: :class:`~transformers.TransfoXLConfig` (Transformer-XL model)
+            - `xlnet`: :class:`~transformers.XLNetConfig` (XLNet model)
+            - `xlm`: :class:`~transformers.XLMConfig` (XLM model)
+            - `ctrl` : :class:`~transformers.CTRLConfig` (CTRL model)
+            - `flaubert` : :class:`~transformers.FlaubertConfig` (Flaubert model)
+            - `electra` : :class:`~transformers.ElectraConfig` (ELECTRA model)

         Args:
             pretrained_model_name_or_path (:obj:`string`):

@@ -193,9 +193,7 @@ class AutoConfig:
             assert unused_kwargs == {'foo': False}

         """
-        config_dict, _ = PretrainedConfig.get_config_dict(
-            pretrained_model_name_or_path, pretrained_config_archive_map=ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, **kwargs
-        )
+        config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)

         if "model_type" in config_dict:
             config_class = CONFIG_MAPPING[config_dict["model_type"]]
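Note: with the archive-map plumbing unhooked, get_config_dict resolves everything from the identifier itself: a hub model id, a local directory, or keyword overrides on top of either. A minimal sketch of the call shapes the simplified AutoConfig.from_pretrained supports (the local path is a placeholder):

    from transformers import AutoConfig

    # 1. A model id on huggingface.co/models (namespaced or canonical).
    config = AutoConfig.from_pretrained("cl-tohoku/bert-base-japanese")

    # 2. A local directory produced by save_pretrained().
    config = AutoConfig.from_pretrained("/tmp/my-finetuned-model")  # placeholder

    # 3. Keyword overrides still apply on top of the loaded dict.
    config = AutoConfig.from_pretrained("bert-base-uncased", num_labels=3)
    print(config.num_labels)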
@@ -23,11 +23,11 @@ from .configuration_utils import PretrainedConfig
 logger = logging.getLogger(__name__)

 BART_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    "bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
-    "bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
-    "bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
-    "bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
-    "mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
+    "facebook/bart-large": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large/config.json",
+    "facebook/bart-large-mnli": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-mnli/config.json",
+    "facebook/bart-large-cnn": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-cnn/config.json",
+    "facebook/bart-large-xsum": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/bart-large-xsum/config.json",
+    "facebook/mbart-large-en-ro": "https://s3.amazonaws.com/models.huggingface.co/bert/facebook/mbart-large-en-ro/config.json",
 }

@@ -36,7 +36,6 @@ class BartConfig(PretrainedConfig):
     Configuration class for Bart. Parameters are renamed from the fairseq implementation
     """
     model_type = "bart"
-    pretrained_config_archive_map = BART_PRETRAINED_CONFIG_ARCHIVE_MAP

     def __init__(
         self,
@@ -39,13 +39,14 @@ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
     "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
     "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
-    "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
-    "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
-    "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
-    "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
-    "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
-    "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
-    "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
+    "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/config.json",
+    "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/config.json",
+    "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/config.json",
+    "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/config.json",
+    "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
+    "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
+    "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
+    # See all BERT models at https://huggingface.co/models?filter=bert
 }

@@ -102,12 +103,7 @@ class BertConfig(PretrainedConfig):

         # Accessing the model configuration
         configuration = model.config

-    Attributes:
-        pretrained_config_archive_map (Dict[str, str]):
-            A dictionary containing all the available pre-trained checkpoints.
     """
-    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
     model_type = "bert"

     def __init__(
@@ -36,5 +36,4 @@ class CamembertConfig(RobertaConfig):
     superclass for the appropriate documentation alongside usage examples.
     """

-    pretrained_config_archive_map = CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
     model_type = "camembert"
@@ -27,7 +27,7 @@ CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf

 class CTRLConfig(PretrainedConfig):
     """
-    This is the configuration class to store the configuration of an :class:`~transformers.CTRLModel`.
+    This is the configuration class to store the configuration of a :class:`~transformers.CTRLModel`.
     It is used to instantiate an CTRL model according to the specified arguments, defining the model
     architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
     the `ctrl <https://huggingface.co/ctrl>`__ architecture from SalesForce.

@@ -76,13 +76,8 @@ class CTRLConfig(PretrainedConfig):

         # Accessing the model configuration
         configuration = model.config

-    Attributes:
-        pretrained_config_archive_map (Dict[str, str]):
-            A dictionary containing all the available pre-trained checkpoints.
     """

-    pretrained_config_archive_map = CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP
     model_type = "ctrl"

     def __init__(
@@ -90,12 +90,7 @@ class DistilBertConfig(PretrainedConfig):

         # Accessing the model configuration
         configuration = model.config

-    Attributes:
-        pretrained_config_archive_map (Dict[str, str]):
-            A dictionary containing all the available pre-trained checkpoints.
     """
-    pretrained_config_archive_map = DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
     model_type = "distilbert"

     def __init__(
@@ -89,12 +89,7 @@ class ElectraConfig(PretrainedConfig):

# Accessing the model configuration
configuration = model.config

Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "electra"

def __init__(
@@ -23,10 +23,10 @@ from .configuration_xlm import XLMConfig

logger = logging.getLogger(__name__)

FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
"flaubert-small-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/config.json",
"flaubert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/config.json",
"flaubert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/config.json",
"flaubert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/config.json",
"flaubert/flaubert_small_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_small_cased/config.json",
"flaubert/flaubert_base_uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_uncased/config.json",
"flaubert/flaubert_base_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_base_cased/config.json",
"flaubert/flaubert_large_cased": "https://s3.amazonaws.com/models.huggingface.co/bert/flaubert/flaubert_large_cased/config.json",
}

@@ -142,7 +142,6 @@ class FlaubertConfig(XLMConfig):

text in a given language.
"""

pretrained_config_archive_map = FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "flaubert"

def __init__(self, layerdrop=0.0, pre_norm=False, pad_token_id=2, bos_token_id=0, **kwargs):
@@ -110,13 +110,8 @@ class GPT2Config(PretrainedConfig):

# Accessing the model configuration
configuration = model.config

Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""

pretrained_config_archive_map = GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "gpt2"

def __init__(
@@ -33,7 +33,7 @@ LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {

class LongformerConfig(RobertaConfig):
r"""
This is the configuration class to store the configuration of an :class:`~transformers.LongformerModel`.
This is the configuration class to store the configuration of a :class:`~transformers.LongformerModel`.
It is used to instantiate an Longformer model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the RoBERTa `roberta-base <https://huggingface.co/roberta-base>`__ architecture with a sequence length 4,096.

@@ -59,12 +59,7 @@ class LongformerConfig(RobertaConfig):

# Accessing the model configuration
configuration = model.config

Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "longformer"

def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs):
@@ -18,10 +18,9 @@ from .configuration_bart import BartConfig

PRETRAINED_CONFIG_ARCHIVE_MAP = {
"marian-en-de": "https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/config.json",
"Helsinki-NLP/opus-mt-en-de": "https://s3.amazonaws.com/models.huggingface.co/bert/Helsinki-NLP/opus-mt-en-de/config.json",
}

class MarianConfig(BartConfig):
model_type = "marian"
pretrained_config_archive_map = PRETRAINED_CONFIG_ARCHIVE_MAP
@@ -30,7 +30,7 @@ OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP = {

class OpenAIGPTConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of an :class:`~transformers.OpenAIGPTModel`.
This is the configuration class to store the configuration of a :class:`~transformers.OpenAIGPTModel`.
It is used to instantiate an GPT model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the `GPT <https://huggingface.co/openai-gpt>`__ architecture from OpenAI.

@@ -108,13 +108,8 @@ class OpenAIGPTConfig(PretrainedConfig):

# Accessing the model configuration
configuration = model.config

Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""

pretrained_config_archive_map = OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "openai-gpt"

def __init__(
@@ -135,12 +135,7 @@ class ReformerConfig(PretrainedConfig):

# Accessing the model configuration
configuration = model.config

Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "reformer"

def __init__(
@@ -35,7 +35,7 @@ ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {

class RobertaConfig(BertConfig):
r"""
This is the configuration class to store the configuration of an :class:`~transformers.RobertaModel`.
This is the configuration class to store the configuration of a :class:`~transformers.RobertaModel`.
It is used to instantiate an RoBERTa model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.

@@ -59,12 +59,7 @@ class RobertaConfig(BertConfig):

# Accessing the model configuration
configuration = model.config

Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "roberta"

def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
@@ -59,7 +59,6 @@ class T5Config(PretrainedConfig):

initializer_factor: A factor for initializing all weight matrices (should be kept to 1.0, used for initialization testing).
layer_norm_eps: The epsilon used by LayerNorm.
"""
pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "t5"

def __init__(
@@ -30,7 +30,7 @@ TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP = {

class TransfoXLConfig(PretrainedConfig):
"""
This is the configuration class to store the configuration of an :class:`~transformers.TransfoXLModel`.
This is the configuration class to store the configuration of a :class:`~transformers.TransfoXLModel`.
It is used to instantiate a Transformer XL model according to the specified arguments, defining the model
architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
the `Transformer XL <https://huggingface.co/transfo-xl-wt103>`__ architecture.

@@ -110,13 +110,8 @@ class TransfoXLConfig(PretrainedConfig):

# Accessing the model configuration
configuration = model.config

Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""

pretrained_config_archive_map = TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "transfo-xl"

def __init__(
@@ -20,7 +20,7 @@ import copy

import json
import logging
import os
from typing import Dict, Optional, Tuple
from typing import Dict, Tuple

from .file_utils import CONFIG_NAME, cached_path, hf_bucket_url, is_remote_url
@@ -37,7 +37,6 @@ class PretrainedConfig(object):

It only affects the model's configuration.

Class attributes (overridden by derived classes):
- ``pretrained_config_archive_map``: a python ``dict`` with `shortcut names` (string) as keys and `url` (string) of associated pretrained model configurations as values.
- ``model_type``: a string that identifies the model type, that we serialize into the JSON file, and that we use to recreate the correct object in :class:`~transformers.AutoConfig`.

Args:
@@ -52,7 +51,6 @@ class PretrainedConfig(object):

torchscript (:obj:`bool`, `optional`, defaults to :obj:`False`):
Is the model used with Torchscript (for PyTorch models).
"""
pretrained_config_archive_map: Dict[str, str] = {}
model_type: str = ""

def __init__(self, **kwargs):
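Every model-specific config touched by this commit follows the pattern these class attributes describe: subclass `PretrainedConfig` and override them. A schematic sketch; the class name, map, and URL below are illustrative only, not part of the commit:

    from transformers.configuration_utils import PretrainedConfig

    MY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
        "my-org/my-model": "https://s3.amazonaws.com/models.huggingface.co/bert/my-org/my-model/config.json",
    }

    class MyConfig(PretrainedConfig):
        # overrides the empty defaults declared on PretrainedConfig
        pretrained_config_archive_map = MY_PRETRAINED_CONFIG_ARCHIVE_MAP
        model_type = "my-model"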
@@ -204,9 +202,7 @@ class PretrainedConfig(object):

return cls.from_dict(config_dict, **kwargs)

@classmethod
def get_config_dict(
cls, pretrained_model_name_or_path: str, pretrained_config_archive_map: Optional[Dict] = None, **kwargs
) -> Tuple[Dict, Dict]:
def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs) -> Tuple[Dict, Dict]:
"""
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
for instantiating a Config using `from_dict`.
@@ -214,8 +210,6 @@ class PretrainedConfig(object):

Parameters:
pretrained_model_name_or_path (:obj:`string`):
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
pretrained_config_archive_map: (:obj:`Dict[str, str]`, `optional`) Dict:
A map of `shortcut names` to `url`. By default, will use the current class attribute.

Returns:
:obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object.
@@ -227,12 +221,7 @@ class PretrainedConfig(object):

proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)

if pretrained_config_archive_map is None:
pretrained_config_archive_map = cls.pretrained_config_archive_map

if pretrained_model_name_or_path in pretrained_config_archive_map:
config_file = pretrained_config_archive_map[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
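With the archive-map branch removed, only three cases remain in this hunk; the hunk ends before the final branch, but given that `hf_bucket_url` is imported from `.file_utils` above, the natural reading is that anything else is treated as a model id on huggingface.co. A hedged sketch of the resulting resolution order (the last branch and its keyword argument are assumed, not shown in the hunk):

    import os
    from transformers.file_utils import CONFIG_NAME, hf_bucket_url, is_remote_url

    def resolve_config_file(name_or_path: str) -> str:
        # schematic resolution order after this commit
        if os.path.isdir(name_or_path):
            return os.path.join(name_or_path, CONFIG_NAME)  # local checkout
        if os.path.isfile(name_or_path) or is_remote_url(name_or_path):
            return name_or_path  # explicit file or URL
        # otherwise: assumed to be a model id on huggingface.co
        return hf_bucket_url(name_or_path, filename=CONFIG_NAME)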
@@ -255,21 +244,11 @@ class PretrainedConfig(object):

config_dict = cls._dict_from_json_file(resolved_config_file)

except EnvironmentError:
if pretrained_model_name_or_path in pretrained_config_archive_map:
msg = "Couldn't reach server at '{}' to download pretrained model configuration file.".format(
config_file
)
else:
msg = (
"Can't load '{}'. Make sure that:\n\n"
"- '{}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
"- or '{}' is the correct path to a directory containing a '{}' file\n\n".format(
pretrained_model_name_or_path,
pretrained_model_name_or_path,
pretrained_model_name_or_path,
CONFIG_NAME,
)
)
msg = (
f"Can't load config for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a {CONFIG_NAME} file\n\n"
)
raise EnvironmentError(msg)

except json.JSONDecodeError:
@@ -152,13 +152,8 @@ class XLMConfig(PretrainedConfig):

# Accessing the model configuration
configuration = model.config

Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""

pretrained_config_archive_map = XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "xlm"

def __init__(
@@ -39,5 +39,4 @@ class XLMRobertaConfig(RobertaConfig):

superclass for the appropriate documentation alongside usage examples.
"""

pretrained_config_archive_map = XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "xlm-roberta"
@@ -122,13 +122,8 @@ class XLNetConfig(PretrainedConfig):

# Accessing the model configuration
configuration = model.config

Attributes:
pretrained_config_archive_map (Dict[str, str]):
A dictionary containing all the available pre-trained checkpoints.
"""

pretrained_config_archive_map = XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
model_type = "xlnet"

def __init__(
@@ -32,6 +32,7 @@ from transformers import (

ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
WEIGHTS_NAME,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,

@@ -70,6 +71,7 @@ from transformers import (

XLMRobertaConfig,
XLNetConfig,
cached_path,
hf_bucket_url,
is_torch_available,
load_pytorch_checkpoint_in_tf2_model,
)
@@ -82,261 +84,103 @@ if is_torch_available():

BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetLMHeadModel,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMWithLMHeadModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMRobertaForMaskedLM,
TransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM,
RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
CamembertForMaskedLM,
CamembertForSequenceClassification,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FlaubertWithLMHeadModel,
DistilBertForMaskedLM,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForPreTraining,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5ForConditionalGeneration,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
ElectraForPreTraining,
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
)
else:
(
BertForPreTraining,
BertForQuestionAnswering,
BertForSequenceClassification,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetLMHeadModel,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMWithLMHeadModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMRobertaForMaskedLM,
TransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
OpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM,
RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
CamembertForMaskedLM,
CamembertForSequenceClassification,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FlaubertWithLMHeadModel,
DistilBertForMaskedLM,
DistilBertForSequenceClassification,
DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForPreTraining,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5ForConditionalGeneration,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
ElectraForPreTraining,
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
) = (
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
)
logging.basicConfig(level=logging.INFO)

MODEL_CLASSES = {
"bert": (
BertConfig,
TFBertForPreTraining,
BertForPreTraining,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"bert": (BertConfig, TFBertForPreTraining, BertForPreTraining, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"bert-large-uncased-whole-word-masking-finetuned-squad": (
BertConfig,
TFBertForQuestionAnswering,
BertForQuestionAnswering,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"bert-large-cased-whole-word-masking-finetuned-squad": (
BertConfig,
TFBertForQuestionAnswering,
BertForQuestionAnswering,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"bert-base-cased-finetuned-mrpc": (
BertConfig,
TFBertForSequenceClassification,
BertForSequenceClassification,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"gpt2": (
GPT2Config,
TFGPT2LMHeadModel,
GPT2LMHeadModel,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlnet": (
XLNetConfig,
TFXLNetLMHeadModel,
XLNetLMHeadModel,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"xlm": (
XLMConfig,
TFXLMWithLMHeadModel,
XLMWithLMHeadModel,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"gpt2": (GPT2Config, TFGPT2LMHeadModel, GPT2LMHeadModel, GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"xlnet": (XLNetConfig, TFXLNetLMHeadModel, XLNetLMHeadModel, XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"xlm": (XLMConfig, TFXLMWithLMHeadModel, XLMWithLMHeadModel, XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"xlm-roberta": (
XLMRobertaConfig,
TFXLMRobertaForMaskedLM,
XLMRobertaForMaskedLM,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"transfo-xl": (
TransfoXLConfig,
TFTransfoXLLMHeadModel,
TransfoXLLMHeadModel,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"openai-gpt": (
OpenAIGPTConfig,
TFOpenAIGPTLMHeadModel,
OpenAIGPTLMHeadModel,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"roberta": (
RobertaConfig,
TFRobertaForMaskedLM,
RobertaForMaskedLM,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"roberta": (RobertaConfig, TFRobertaForMaskedLM, RobertaForMaskedLM, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"roberta-large-mnli": (
RobertaConfig,
TFRobertaForSequenceClassification,
RobertaForSequenceClassification,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"camembert": (
CamembertConfig,
TFCamembertForMaskedLM,
CamembertForMaskedLM,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"flaubert": (
FlaubertConfig,
TFFlaubertWithLMHeadModel,
FlaubertWithLMHeadModel,
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"distilbert": (
DistilBertConfig,
TFDistilBertForMaskedLM,
DistilBertForMaskedLM,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"distilbert-base-distilled-squad": (
DistilBertConfig,
TFDistilBertForQuestionAnswering,
DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"ctrl": (
CTRLConfig,
TFCTRLLMHeadModel,
CTRLLMHeadModel,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"albert": (
AlbertConfig,
TFAlbertForPreTraining,
AlbertForPreTraining,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"t5": (
T5Config,
TFT5ForConditionalGeneration,
T5ForConditionalGeneration,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"electra": (
ElectraConfig,
TFElectraForPreTraining,
ElectraForPreTraining,
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,
),
"ctrl": (CTRLConfig, TFCTRLLMHeadModel, CTRLLMHeadModel, CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"albert": (AlbertConfig, TFAlbertForPreTraining, AlbertForPreTraining, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"t5": (T5Config, TFT5ForConditionalGeneration, T5ForConditionalGeneration, T5_PRETRAINED_CONFIG_ARCHIVE_MAP,),
"electra": (ElectraConfig, TFElectraForPreTraining, ElectraForPreTraining, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP,),
}
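Every `MODEL_CLASSES` entry shrinks from a 5-tuple to a 4-tuple: the per-model weights map is gone and only the config archive map remains, matching the updated unpacking in `convert_pt_checkpoint_to_tf` below. A hedged sketch of consuming an entry:

    # after this commit: (config class, TF model class, PyTorch model class, config archive map)
    config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES["bert"]
    assert "bert-base-uncased" in aws_config_map  # shortcut names still key the config map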
@@ -346,7 +190,7 @@ def convert_pt_checkpoint_to_tf(

if model_type not in MODEL_CLASSES:
raise ValueError("Unrecognized model type, should be one of {}.".format(list(MODEL_CLASSES.keys())))

config_class, model_class, pt_model_class, aws_model_maps, aws_config_map = MODEL_CLASSES[model_type]
config_class, model_class, pt_model_class, aws_config_map = MODEL_CLASSES[model_type]

# Initialise TF model
if config_file in aws_config_map:

@@ -358,10 +202,9 @@ def convert_pt_checkpoint_to_tf(

tf_model = model_class(config)

# Load weights from tf checkpoint
if pytorch_checkpoint_path in aws_model_maps:
pytorch_checkpoint_path = cached_path(
aws_model_maps[pytorch_checkpoint_path], force_download=not use_cached_models
)
if pytorch_checkpoint_path in aws_config_map.keys():
pytorch_checkpoint_url = hf_bucket_url(pytorch_checkpoint_path, filename=WEIGHTS_NAME)
pytorch_checkpoint_path = cached_path(pytorch_checkpoint_url, force_download=not use_cached_models)
# Load PyTorch checkpoint in tf2 model:
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
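Checkpoint URLs are now derived from the model id instead of looked up in `aws_model_maps`: any identifier present in the config map is turned into a bucket URL for its `pytorch_model.bin` and cached locally. A minimal sketch mirroring the two new lines above (all names imported from `transformers`, as in this script's import block):

    from transformers import WEIGHTS_NAME, cached_path, hf_bucket_url

    # e.g. "bert-base-uncased" -> URL of its pytorch_model.bin, downloaded into the local cache
    checkpoint_url = hf_bucket_url("bert-base-uncased", filename=WEIGHTS_NAME)
    local_checkpoint = cached_path(checkpoint_url)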
@@ -31,16 +31,17 @@ from .modeling_utils import PreTrainedModel

logger = logging.getLogger(__name__)

ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
"albert-base-v1": "https://cdn.huggingface.co/albert-base-v1-pytorch_model.bin",
"albert-large-v1": "https://cdn.huggingface.co/albert-large-v1-pytorch_model.bin",
"albert-xlarge-v1": "https://cdn.huggingface.co/albert-xlarge-v1-pytorch_model.bin",
"albert-xxlarge-v1": "https://cdn.huggingface.co/albert-xxlarge-v1-pytorch_model.bin",
"albert-base-v2": "https://cdn.huggingface.co/albert-base-v2-pytorch_model.bin",
"albert-large-v2": "https://cdn.huggingface.co/albert-large-v2-pytorch_model.bin",
"albert-xlarge-v2": "https://cdn.huggingface.co/albert-xlarge-v2-pytorch_model.bin",
"albert-xxlarge-v2": "https://cdn.huggingface.co/albert-xxlarge-v2-pytorch_model.bin",
}
ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"albert-base-v1",
"albert-large-v1",
"albert-xlarge-v1",
"albert-xxlarge-v1",
"albert-base-v2",
"albert-large-v2",
"albert-xlarge-v2",
"albert-xxlarge-v2",
# See all ALBERT models at https://huggingface.co/models?filter=albert
]

def load_tf_weights_in_albert(model, config, tf_checkpoint_path):

@@ -365,7 +366,6 @@ class AlbertPreTrainedModel(PreTrainedModel):

"""

config_class = AlbertConfig
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "albert"

def _init_weights(self, module):

@@ -439,7 +439,6 @@ ALBERT_INPUTS_DOCSTRING = r"""

class AlbertModel(AlbertPreTrainedModel):

config_class = AlbertConfig
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_albert
base_model_prefix = "albert"
@@ -43,7 +43,6 @@ from .configuration_auto import (

from .configuration_marian import MarianConfig
from .configuration_utils import PretrainedConfig
from .modeling_albert import (
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
AlbertForMaskedLM,
AlbertForPreTraining,
AlbertForQuestionAnswering,

@@ -51,14 +50,8 @@ from .modeling_albert import (

AlbertForTokenClassification,
AlbertModel,
)
from .modeling_bart import (
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
BartForConditionalGeneration,
BartForSequenceClassification,
BartModel,
)
from .modeling_bart import BartForConditionalGeneration, BartForSequenceClassification, BartModel
from .modeling_bert import (
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BertForMaskedLM,
BertForMultipleChoice,
BertForPreTraining,

@@ -68,16 +61,14 @@ from .modeling_bert import (

BertModel,
)
from .modeling_camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CamembertForMaskedLM,
CamembertForMultipleChoice,
CamembertForSequenceClassification,
CamembertForTokenClassification,
CamembertModel,
)
from .modeling_ctrl import CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel, CTRLModel
from .modeling_ctrl import CTRLLMHeadModel, CTRLModel
from .modeling_distilbert import (
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
DistilBertForMaskedLM,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,

@@ -85,7 +76,6 @@ from .modeling_distilbert import (

DistilBertModel,
)
from .modeling_electra import (
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
ElectraForMaskedLM,
ElectraForPreTraining,
ElectraForSequenceClassification,

@@ -94,15 +84,13 @@ from .modeling_electra import (

)
from .modeling_encoder_decoder import EncoderDecoderModel
from .modeling_flaubert import (
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FlaubertForQuestionAnsweringSimple,
FlaubertForSequenceClassification,
FlaubertModel,
FlaubertWithLMHeadModel,
)
from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model
from .modeling_gpt2 import GPT2LMHeadModel, GPT2Model
from .modeling_longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
LongformerForMaskedLM,
LongformerForMultipleChoice,
LongformerForQuestionAnswering,

@@ -111,10 +99,9 @@ from .modeling_longformer import (

LongformerModel,
)
from .modeling_marian import MarianMTModel
from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_openai import OpenAIGPTLMHeadModel, OpenAIGPTModel
from .modeling_reformer import ReformerModel, ReformerModelWithLMHead
from .modeling_roberta import (
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
RobertaForMaskedLM,
RobertaForMultipleChoice,
RobertaForQuestionAnswering,

@@ -122,10 +109,9 @@ from .modeling_roberta import (

RobertaForTokenClassification,
RobertaModel,
)
from .modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5ForConditionalGeneration, T5Model
from .modeling_transfo_xl import TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP, TransfoXLLMHeadModel, TransfoXLModel
from .modeling_t5 import T5ForConditionalGeneration, T5Model
from .modeling_transfo_xl import TransfoXLLMHeadModel, TransfoXLModel
from .modeling_xlm import (
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMForQuestionAnsweringSimple,
XLMForSequenceClassification,
XLMForTokenClassification,

@@ -133,7 +119,6 @@ from .modeling_xlm import (

XLMWithLMHeadModel,
)
from .modeling_xlm_roberta import (
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
XLMRobertaForMaskedLM,
XLMRobertaForMultipleChoice,
XLMRobertaForSequenceClassification,

@@ -141,7 +126,6 @@ from .modeling_xlm_roberta import (

XLMRobertaModel,
)
from .modeling_xlnet import (
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNetForMultipleChoice,
XLNetForQuestionAnsweringSimple,
XLNetForSequenceClassification,

@@ -154,30 +138,6 @@ from .modeling_xlnet import (

logger = logging.getLogger(__name__)

ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
(key, value)
for pretrained_map in [
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
BART_PRETRAINED_MODEL_ARCHIVE_MAP,
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
T5_PRETRAINED_MODEL_ARCHIVE_MAP,
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP,
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
]
for key, value, in pretrained_map.items()
)

MODEL_MAPPING = OrderedDict(
[
(T5Config, T5Model),
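With the aggregated `ALL_PRETRAINED_MODEL_ARCHIVE_MAP` deleted, the `Auto*` classes rely on `MODEL_MAPPING` (begun above) plus the config's `model_type` to pick a class. A sketch of the dispatch idea, assuming the mapping iterates in registration order and matches on the configuration class:

    from collections import OrderedDict
    from transformers import BertConfig, BertModel, T5Config, T5Model

    # truncated illustration of the real MODEL_MAPPING
    MODEL_MAPPING = OrderedDict([(T5Config, T5Model), (BertConfig, BertModel)])

    def pick_model_class(config):
        # first matching config class wins, so more specific subclasses must be registered first
        for config_class, model_class in MODEL_MAPPING.items():
            if isinstance(config, config_class):
                return model_class
        raise ValueError("Unrecognized configuration class {}".format(type(config)))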
@@ -372,29 +332,26 @@ class AutoModel:

The `from_pretrained()` method takes care of returning the correct model class instance
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
- `t5`: :class:`~transformers.T5Model` (T5 model)
- `distilbert`: :class:`~transformers.DistilBertModel` (DistilBERT model)
- `albert`: :class:`~transformers.AlbertModel` (ALBERT model)
- `camembert`: :class:`~transformers.CamembertModel` (CamemBERT model)
- `xlm-roberta`: :class:`~transformers.XLMRobertaModel` (XLM-RoBERTa model)
- `longformer` :class:`~transformers.LongformerModel` (Longformer model)
- `roberta`: :class:`~transformers.RobertaModel` (RoBERTa model)
- `bert`: :class:`~transformers.BertModel` (Bert model)
- `openai-gpt`: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
- `gpt2`: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
- `transfo-xl`: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
- `xlnet`: :class:`~transformers.XLNetModel` (XLNet model)
- `xlm`: :class:`~transformers.XLMModel` (XLM model)
- `ctrl`: :class:`~transformers.CTRLModel` (Salesforce CTRL model)
- `flaubert`: :class:`~transformers.FlaubertModel` (Flaubert model)
- `electra`: :class:`~transformers.ElectraModel` (Electra model)

The base model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: :class:`~transformers.T5Model` (T5 model)
- contains `distilbert`: :class:`~transformers.DistilBertModel` (DistilBERT model)
- contains `albert`: :class:`~transformers.AlbertModel` (ALBERT model)
- contains `camembert`: :class:`~transformers.CamembertModel` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaModel` (XLM-RoBERTa model)
- contains `longformer` :class:`~transformers.LongformerModel` (Longformer model)
- contains `roberta`: :class:`~transformers.RobertaModel` (RoBERTa model)
- contains `bert`: :class:`~transformers.BertModel` (Bert model)
- contains `openai-gpt`: :class:`~transformers.OpenAIGPTModel` (OpenAI GPT model)
- contains `gpt2`: :class:`~transformers.GPT2Model` (OpenAI GPT-2 model)
- contains `transfo-xl`: :class:`~transformers.TransfoXLModel` (Transformer-XL model)
- contains `xlnet`: :class:`~transformers.XLNetModel` (XLNet model)
- contains `xlm`: :class:`~transformers.XLMModel` (XLM model)
- contains `ctrl`: :class:`~transformers.CTRLModel` (Salesforce CTRL model)
- contains `flaubert`: :class:`~transformers.FlaubertModel` (Flaubert model)
- contains `electra`: :class:`~transformers.ElectraModel` (Electra model)

The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`

Args:
pretrained_model_name_or_path: either:

@@ -528,26 +485,23 @@ class AutoModelForPreTraining:

The `from_pretrained()` method takes care of returning the correct model class instance
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.

The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: :class:`~transformers.T5ModelWithLMHead` (T5 model)
- contains `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
- contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
- contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
- contains `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
- contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
- contains `bert`: :class:`~transformers.BertForPreTraining` (Bert model)
- contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
- contains `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
- contains `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
- contains `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
- contains `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
- contains `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
- contains `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
- contains `electra`: :class:`~transformers.ElectraForPreTraining` (Electra model)
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
- `t5`: :class:`~transformers.T5ModelWithLMHead` (T5 model)
- `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
- `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
- `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
- `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
- `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
- `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
- `bert`: :class:`~transformers.BertForPreTraining` (Bert model)
- `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
- `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
- `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
- `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
- `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
- `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
- `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
- `electra`: :class:`~transformers.ElectraForPreTraining` (Electra model)

The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`

@@ -679,26 +633,23 @@ class AutoModelWithLMHead:

The `from_pretrained()` method takes care of returning the correct model class instance
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.

The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `t5`: :class:`~transformers.T5ModelWithLMHead` (T5 model)
- contains `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
- contains `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
- contains `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
- contains `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
- contains `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
- contains `bert`: :class:`~transformers.BertForMaskedLM` (Bert model)
- contains `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
- contains `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
- contains `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
- contains `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
- contains `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
- contains `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
- contains `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
- contains `electra`: :class:`~transformers.ElectraForMaskedLM` (Electra model)
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
- `t5`: :class:`~transformers.T5ModelWithLMHead` (T5 model)
- `distilbert`: :class:`~transformers.DistilBertForMaskedLM` (DistilBERT model)
- `albert`: :class:`~transformers.AlbertForMaskedLM` (ALBERT model)
- `camembert`: :class:`~transformers.CamembertForMaskedLM` (CamemBERT model)
- `xlm-roberta`: :class:`~transformers.XLMRobertaForMaskedLM` (XLM-RoBERTa model)
- `longformer`: :class:`~transformers.LongformerForMaskedLM` (Longformer model)
- `roberta`: :class:`~transformers.RobertaForMaskedLM` (RoBERTa model)
- `bert`: :class:`~transformers.BertForMaskedLM` (Bert model)
- `openai-gpt`: :class:`~transformers.OpenAIGPTLMHeadModel` (OpenAI GPT model)
- `gpt2`: :class:`~transformers.GPT2LMHeadModel` (OpenAI GPT-2 model)
- `transfo-xl`: :class:`~transformers.TransfoXLLMHeadModel` (Transformer-XL model)
- `xlnet`: :class:`~transformers.XLNetLMHeadModel` (XLNet model)
- `xlm`: :class:`~transformers.XLMWithLMHeadModel` (XLM model)
- `ctrl`: :class:`~transformers.CTRLLMHeadModel` (Salesforce CTRL model)
- `flaubert`: :class:`~transformers.FlaubertWithLMHeadModel` (Flaubert model)
- `electra`: :class:`~transformers.ElectraForMaskedLM` (Electra model)

The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`

@@ -830,18 +781,15 @@ class AutoModelForSequenceClassification:

The `from_pretrained()` method takes care of returning the correct model class instance
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.

The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
- contains `albert`: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
- contains `camembert`: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
- contains `roberta`: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
- contains `bert`: :class:`~transformers.BertForSequenceClassification` (Bert model)
- contains `xlnet`: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
- contains `flaubert`: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
- `distilbert`: :class:`~transformers.DistilBertForSequenceClassification` (DistilBERT model)
- `albert`: :class:`~transformers.AlbertForSequenceClassification` (ALBERT model)
- `camembert`: :class:`~transformers.CamembertForSequenceClassification` (CamemBERT model)
- `xlm-roberta`: :class:`~transformers.XLMRobertaForSequenceClassification` (XLM-RoBERTa model)
- `roberta`: :class:`~transformers.RobertaForSequenceClassification` (RoBERTa model)
- `bert`: :class:`~transformers.BertForSequenceClassification` (Bert model)
- `xlnet`: :class:`~transformers.XLNetForSequenceClassification` (XLNet model)
- `flaubert`: :class:`~transformers.FlaubertForSequenceClassification` (Flaubert model)

The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`

@@ -979,16 +927,13 @@ class AutoModelForQuestionAnswering:

The `from_pretrained()` method takes care of returning the correct model class instance
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.

The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
- contains `albert`: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
- contains `bert`: :class:`~transformers.BertForQuestionAnswering` (Bert model)
- contains `xlnet`: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
- contains `xlm`: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
- contains `flaubert`: :class:`~transformers.FlaubertForQuestionAnswering` (XLM model)
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
- `distilbert`: :class:`~transformers.DistilBertForQuestionAnswering` (DistilBERT model)
- `albert`: :class:`~transformers.AlbertForQuestionAnswering` (ALBERT model)
- `bert`: :class:`~transformers.BertForQuestionAnswering` (Bert model)
- `xlnet`: :class:`~transformers.XLNetForQuestionAnswering` (XLNet model)
- `xlm`: :class:`~transformers.XLMForQuestionAnswering` (XLM model)
- `flaubert`: :class:`~transformers.FlaubertForQuestionAnswering` (XLM model)

The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`

@@ -1127,18 +1072,15 @@ class AutoModelForTokenClassification:

The `from_pretrained()` method takes care of returning the correct model class instance
based on the `model_type` property of the config object, or when it's missing,
falling back to using pattern matching on the `pretrained_model_name_or_path` string.

The model class to instantiate is selected as the first pattern matching
in the `pretrained_model_name_or_path` string (in the following order):
- contains `distilbert`: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
- contains `xlm`: :class:`~transformers.XLMForTokenClassification` (XLM model)
- contains `xlm-roberta`: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
- contains `camembert`: :class:`~transformers.CamembertForTokenClassification` (Camembert model)
- contains `bert`: :class:`~transformers.BertForTokenClassification` (Bert model)
- contains `xlnet`: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
- contains `roberta`: :class:`~transformers.RobertaForTokenClassification` (Roberta model)
- contains `electra`: :class:`~transformers.ElectraForTokenClassification` (Electra model)
falling back to using pattern matching on the `pretrained_model_name_or_path` string:
- `distilbert`: :class:`~transformers.DistilBertForTokenClassification` (DistilBERT model)
- `xlm`: :class:`~transformers.XLMForTokenClassification` (XLM model)
- `xlm-roberta`: :class:`~transformers.XLMRobertaForTokenClassification` (XLM-RoBERTa model)
- `camembert`: :class:`~transformers.CamembertForTokenClassification` (Camembert model)
- `bert`: :class:`~transformers.BertForTokenClassification` (Bert model)
- `xlnet`: :class:`~transformers.XLNetForTokenClassification` (XLNet model)
- `roberta`: :class:`~transformers.RobertaForTokenClassification` (Roberta model)
- `electra`: :class:`~transformers.ElectraForTokenClassification` (Electra model)

The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you should first set it back in training mode with `model.train()`
@@ -32,13 +32,15 @@ from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids

logger = logging.getLogger(__name__)

BART_PRETRAINED_MODEL_ARCHIVE_MAP = {
"bart-large": "https://cdn.huggingface.co/facebook/bart-large/pytorch_model.bin",
"bart-large-mnli": "https://cdn.huggingface.co/facebook/bart-large-mnli/pytorch_model.bin",
"bart-large-cnn": "https://cdn.huggingface.co/facebook/bart-large-cnn/pytorch_model.bin",
"bart-large-xsum": "https://cdn.huggingface.co/facebook/bart-large-xsum/pytorch_model.bin",
"mbart-large-en-ro": "https://cdn.huggingface.co/facebook/mbart-large-en-ro/pytorch_model.bin",
}
BART_PRETRAINED_MODEL_ARCHIVE_LIST = [
"facebook/bart-large",
"facebook/bart-large-mnli",
"facebook/bart-large-cnn",
"facebook/bart-large-xsum",
"facebook/mbart-large-en-ro",
# See all BART models at https://huggingface.co/models?filter=bart
]

BART_START_DOCSTRING = r"""

@@ -118,7 +120,6 @@ def _prepare_bart_decoder_inputs(

class PretrainedBartModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
pretrained_model_archive_map = BART_PRETRAINED_MODEL_ARCHIVE_MAP

def _init_weights(self, module):
std = self.config.init_std
@@ -32,30 +32,31 @@ from .modeling_utils import PreTrainedModel, prune_linear_layer

logger = logging.getLogger(__name__)

BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
"bert-base-uncased": "https://cdn.huggingface.co/bert-base-uncased-pytorch_model.bin",
"bert-large-uncased": "https://cdn.huggingface.co/bert-large-uncased-pytorch_model.bin",
"bert-base-cased": "https://cdn.huggingface.co/bert-base-cased-pytorch_model.bin",
"bert-large-cased": "https://cdn.huggingface.co/bert-large-cased-pytorch_model.bin",
"bert-base-multilingual-uncased": "https://cdn.huggingface.co/bert-base-multilingual-uncased-pytorch_model.bin",
"bert-base-multilingual-cased": "https://cdn.huggingface.co/bert-base-multilingual-cased-pytorch_model.bin",
"bert-base-chinese": "https://cdn.huggingface.co/bert-base-chinese-pytorch_model.bin",
"bert-base-german-cased": "https://cdn.huggingface.co/bert-base-german-cased-pytorch_model.bin",
"bert-large-uncased-whole-word-masking": "https://cdn.huggingface.co/bert-large-uncased-whole-word-masking-pytorch_model.bin",
"bert-large-cased-whole-word-masking": "https://cdn.huggingface.co/bert-large-cased-whole-word-masking-pytorch_model.bin",
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://cdn.huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
"bert-large-cased-whole-word-masking-finetuned-squad": "https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
"bert-base-cased-finetuned-mrpc": "https://cdn.huggingface.co/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
"bert-base-german-dbmdz-cased": "https://cdn.huggingface.co/bert-base-german-dbmdz-cased-pytorch_model.bin",
"bert-base-german-dbmdz-uncased": "https://cdn.huggingface.co/bert-base-german-dbmdz-uncased-pytorch_model.bin",
"bert-base-japanese": "https://cdn.huggingface.co/cl-tohoku/bert-base-japanese/pytorch_model.bin",
"bert-base-japanese-whole-word-masking": "https://cdn.huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/pytorch_model.bin",
"bert-base-japanese-char": "https://cdn.huggingface.co/cl-tohoku/bert-base-japanese-char/pytorch_model.bin",
"bert-base-japanese-char-whole-word-masking": "https://cdn.huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/pytorch_model.bin",
"bert-base-finnish-cased-v1": "https://cdn.huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
"bert-base-finnish-uncased-v1": "https://cdn.huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
"bert-base-dutch-cased": "https://cdn.huggingface.co/wietsedv/bert-base-dutch-cased/pytorch_model.bin",
}
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"bert-base-uncased",
"bert-large-uncased",
"bert-base-cased",
"bert-large-cased",
"bert-base-multilingual-uncased",
"bert-base-multilingual-cased",
"bert-base-chinese",
"bert-base-german-cased",
"bert-large-uncased-whole-word-masking",
"bert-large-cased-whole-word-masking",
"bert-large-uncased-whole-word-masking-finetuned-squad",
"bert-large-cased-whole-word-masking-finetuned-squad",
"bert-base-cased-finetuned-mrpc",
"bert-base-german-dbmdz-cased",
"bert-base-german-dbmdz-uncased",
"cl-tohoku/bert-base-japanese",
"cl-tohoku/bert-base-japanese-whole-word-masking",
"cl-tohoku/bert-base-japanese-char",
"cl-tohoku/bert-base-japanese-char-whole-word-masking",
"TurkuNLP/bert-base-finnish-cased-v1",
"TurkuNLP/bert-base-finnish-uncased-v1",
"wietsedv/bert-base-dutch-cased",
# See all BERT models at https://huggingface.co/models?filter=bert
]

def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
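The flat CDN filenames give way to per-organization paths, and the community checkpoints keep their namespaced ids in the new list, which is now purely informative. From the user's side loading is unchanged; a minimal sketch, assuming the standard `from_pretrained` API:

    from transformers import BertModel

    # the namespaced id maps to the cl-tohoku/bert-base-japanese pytorch_model.bin, as in the old map above
    model = BertModel.from_pretrained("cl-tohoku/bert-base-japanese")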
@@ -513,7 +514,6 @@ class BertPreTrainedModel(PreTrainedModel):

"""

config_class = BertConfig
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
@@ -31,11 +31,12 @@ from .modeling_roberta import (

logger = logging.getLogger(__name__)

CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
"camembert-base": "https://cdn.huggingface.co/camembert-base-pytorch_model.bin",
"umberto-commoncrawl-cased-v1": "https://cdn.huggingface.co/Musixmatch/umberto-commoncrawl-cased-v1/pytorch_model.bin",
"umberto-wikipedia-uncased-v1": "https://cdn.huggingface.co/Musixmatch/umberto-wikipedia-uncased-v1/pytorch_model.bin",
}
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
"camembert-base",
"Musixmatch/umberto-commoncrawl-cased-v1",
"Musixmatch/umberto-wikipedia-uncased-v1",
# See all CamemBERT models at https://huggingface.co/models?filter=camembert
]

CAMEMBERT_START_DOCSTRING = r"""

@@ -62,7 +63,6 @@ class CamembertModel(RobertaModel):

"""

config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP

@add_start_docstrings(

@@ -75,7 +75,6 @@ class CamembertForMaskedLM(RobertaForMaskedLM):

"""

config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP

@add_start_docstrings(

@@ -90,7 +89,6 @@ class CamembertForSequenceClassification(RobertaForSequenceClassification):

"""

config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP

@add_start_docstrings(

@@ -105,7 +103,6 @@ class CamembertForMultipleChoice(RobertaForMultipleChoice):

"""

config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP

@add_start_docstrings(

@@ -120,7 +117,6 @@ class CamembertForTokenClassification(RobertaForTokenClassification):

"""

config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP

@add_start_docstrings(

@@ -135,4 +131,3 @@ class CamembertForQuestionAnswering(RobertaForQuestionAnswering):

"""

config_class = CamembertConfig
pretrained_model_archive_map = CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP
@ -30,7 +30,10 @@ from .modeling_utils import Conv1D, PreTrainedModel
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://storage.googleapis.com/sf-ctrl/pytorch/seqlen256_v1.bin"}
|
||||
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"ctrl"
|
||||
# See all CTRL models at https://huggingface.co/models?filter=ctrl
|
||||
]
|
||||
|
||||
|
||||
def angle_defn(pos, i, d_model_size):
|
||||
|
@ -178,7 +181,6 @@ class CTRLPreTrainedModel(PreTrainedModel):
|
|||
"""
|
||||
|
||||
config_class = CTRLConfig
|
||||
pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
base_model_prefix = "transformer"
|
||||
|
||||
def _init_weights(self, module):
|
||||
|
|
|
@ -36,15 +36,16 @@ from .modeling_utils import PreTrainedModel, prune_linear_layer
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
"distilbert-base-uncased": "https://cdn.huggingface.co/distilbert-base-uncased-pytorch_model.bin",
|
||||
"distilbert-base-uncased-distilled-squad": "https://cdn.huggingface.co/distilbert-base-uncased-distilled-squad-pytorch_model.bin",
|
||||
"distilbert-base-cased": "https://cdn.huggingface.co/distilbert-base-cased-pytorch_model.bin",
|
||||
"distilbert-base-cased-distilled-squad": "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-pytorch_model.bin",
|
||||
"distilbert-base-german-cased": "https://cdn.huggingface.co/distilbert-base-german-cased-pytorch_model.bin",
|
||||
"distilbert-base-multilingual-cased": "https://cdn.huggingface.co/distilbert-base-multilingual-cased-pytorch_model.bin",
|
||||
"distilbert-base-uncased-finetuned-sst-2-english": "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-pytorch_model.bin",
|
||||
}
|
||||
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"distilbert-base-uncased",
|
||||
"distilbert-base-uncased-distilled-squad",
|
||||
"distilbert-base-cased",
|
||||
"distilbert-base-cased-distilled-squad",
|
||||
"distilbert-base-german-cased",
|
||||
"distilbert-base-multilingual-cased",
|
||||
"distilbert-base-uncased-finetuned-sst-2-english",
|
||||
# See all DistilBERT models at https://huggingface.co/models?filter=distilbert
|
||||
]
|
||||
|
||||
|
||||
# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
|
||||
|
@ -327,7 +328,6 @@ class DistilBertPreTrainedModel(PreTrainedModel):
|
|||
"""
|
||||
|
||||
config_class = DistilBertConfig
|
||||
pretrained_model_archive_map = DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = None
|
||||
base_model_prefix = "distilbert"
|
||||
|
||||
|
|
|
@ -14,14 +14,15 @@ from .modeling_bert import BertEmbeddings, BertEncoder, BertLayerNorm, BertPreTr
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
"google/electra-small-generator": "https://cdn.huggingface.co/google/electra-small-generator/pytorch_model.bin",
|
||||
"google/electra-base-generator": "https://cdn.huggingface.co/google/electra-base-generator/pytorch_model.bin",
|
||||
"google/electra-large-generator": "https://cdn.huggingface.co/google/electra-large-generator/pytorch_model.bin",
|
||||
"google/electra-small-discriminator": "https://cdn.huggingface.co/google/electra-small-discriminator/pytorch_model.bin",
|
||||
"google/electra-base-discriminator": "https://cdn.huggingface.co/google/electra-base-discriminator/pytorch_model.bin",
|
||||
"google/electra-large-discriminator": "https://cdn.huggingface.co/google/electra-large-discriminator/pytorch_model.bin",
|
||||
}
|
||||
ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"google/electra-small-generator",
|
||||
"google/electra-base-generator",
|
||||
"google/electra-large-generator",
|
||||
"google/electra-small-discriminator",
|
||||
"google/electra-base-discriminator",
|
||||
"google/electra-large-discriminator",
|
||||
# See all ELECTRA models at https://huggingface.co/models?filter=electra
|
||||
]
|
||||
|
||||
|
||||
def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_or_generator="discriminator"):
|
||||
|
@ -160,7 +161,6 @@ class ElectraPreTrainedModel(BertPreTrainedModel):
|
|||
"""
|
||||
|
||||
config_class = ElectraConfig
|
||||
pretrained_model_archive_map = ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_electra
|
||||
base_model_prefix = "electra"
|
||||
|
||||
|
|
|
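One detail worth pausing on: `load_tf_weights_in_electra` takes a `discriminator_or_generator` flag because an original TF ELECTRA checkpoint contains both networks, and the flag selects which set of weights to pull out. A hedged call sketch -- the checkpoint path is a placeholder; the signature is the one shown above:

# Sketch: load generator weights from an original TF checkpoint into the
# masked-LM model; "/path/to/electra_small/model.ckpt" is a placeholder path.
from transformers import ElectraConfig, ElectraForMaskedLM
from transformers.modeling_electra import load_tf_weights_in_electra

config = ElectraConfig.from_pretrained("google/electra-small-generator")
model = ElectraForMaskedLM(config)
load_tf_weights_in_electra(
    model, config, "/path/to/electra_small/model.ckpt",
    discriminator_or_generator="generator",
)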
@@ -35,12 +35,13 @@ from .modeling_xlm import (

logger = logging.getLogger(__name__)

FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "flaubert-small-cased": "https://cdn.huggingface.co/flaubert/flaubert_small_cased/pytorch_model.bin",
    "flaubert-base-uncased": "https://cdn.huggingface.co/flaubert/flaubert_base_uncased/pytorch_model.bin",
    "flaubert-base-cased": "https://cdn.huggingface.co/flaubert/flaubert_base_cased/pytorch_model.bin",
    "flaubert-large-cased": "https://cdn.huggingface.co/flaubert/flaubert_large_cased/pytorch_model.bin",
}
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "flaubert/flaubert_small_cased",
    "flaubert/flaubert_base_uncased",
    "flaubert/flaubert_base_cased",
    "flaubert/flaubert_large_cased",
    # See all Flaubert models at https://huggingface.co/models?filter=flaubert
]


FLAUBERT_START_DOCSTRING = r"""

@@ -109,7 +110,6 @@ FLAUBERT_INPUTS_DOCSTRING = r"""
class FlaubertModel(XLMModel):

    config_class = FlaubertConfig
    pretrained_model_archive_map = FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP

    def __init__(self, config):  # , dico, is_encoder, with_output):
        super().__init__(config)

@@ -304,7 +304,6 @@ class FlaubertWithLMHeadModel(XLMWithLMHeadModel):
    """

    config_class = FlaubertConfig
    pretrained_model_archive_map = FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP

    def __init__(self, config):
        super().__init__(config)

@@ -324,7 +323,6 @@ class FlaubertForSequenceClassification(XLMForSequenceClassification):
    """

    config_class = FlaubertConfig
    pretrained_model_archive_map = FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP

    def __init__(self, config):
        super().__init__(config)

@@ -344,7 +342,6 @@ class FlaubertForQuestionAnsweringSimple(XLMForQuestionAnsweringSimple):
    """

    config_class = FlaubertConfig
    pretrained_model_archive_map = FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP

    def __init__(self, config):
        super().__init__(config)

@@ -364,7 +361,6 @@ class FlaubertForQuestionAnswering(XLMForQuestionAnswering):
    """

    config_class = FlaubertConfig
    pretrained_model_archive_map = FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP

    def __init__(self, config):
        super().__init__(config)

@@ -31,13 +31,14 @@ from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv

logger = logging.getLogger(__name__)

GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://cdn.huggingface.co/gpt2-pytorch_model.bin",
    "gpt2-medium": "https://cdn.huggingface.co/gpt2-medium-pytorch_model.bin",
    "gpt2-large": "https://cdn.huggingface.co/gpt2-large-pytorch_model.bin",
    "gpt2-xl": "https://cdn.huggingface.co/gpt2-xl-pytorch_model.bin",
    "distilgpt2": "https://cdn.huggingface.co/distilgpt2-pytorch_model.bin",
}
GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
    "distilgpt2",
    # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
]


def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):

@@ -251,7 +252,6 @@ class GPT2PreTrainedModel(PreTrainedModel):
    """

    config_class = GPT2Config
    pretrained_model_archive_map = GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_gpt2
    base_model_prefix = "transformer"

@@ -30,13 +30,14 @@ from .modeling_roberta import RobertaLMHead, RobertaModel

logger = logging.getLogger(__name__)

LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "allenai/longformer-base-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096/pytorch_model.bin",
    "allenai/longformer-large-4096": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096/pytorch_model.bin",
    "allenai/longformer-large-4096-finetuned-triviaqa": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-finetuned-triviaqa/pytorch_model.bin",
    "allenai/longformer-base-4096-extra.pos.embd.only": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-base-4096-extra.pos.embd.only/pytorch_model.bin",
    "allenai/longformer-large-4096-extra.pos.embd.only": "https://s3.amazonaws.com/models.huggingface.co/bert/allenai/longformer-large-4096-extra.pos.embd.only/pytorch_model.bin",
}
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "allenai/longformer-base-4096",
    "allenai/longformer-large-4096",
    "allenai/longformer-large-4096-finetuned-triviaqa",
    "allenai/longformer-base-4096-extra.pos.embd.only",
    "allenai/longformer-large-4096-extra.pos.embd.only",
    # See all Longformer models at https://huggingface.co/models?filter=longformer
]


def _get_question_end_index(input_ids, sep_token_id):

@@ -513,7 +514,6 @@ class LongformerModel(RobertaModel):
    """

    config_class = LongformerConfig
    pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "longformer"

    def __init__(self, config):

@@ -685,7 +685,6 @@ class LongformerModel(RobertaModel):
@add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
class LongformerForMaskedLM(BertPreTrainedModel):
    config_class = LongformerConfig
    pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "longformer"

    def __init__(self, config):

@@ -776,7 +775,6 @@ class LongformerForMaskedLM(BertPreTrainedModel):
)
class LongformerForSequenceClassification(BertPreTrainedModel):
    config_class = LongformerConfig
    pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "longformer"

    def __init__(self, config):

@@ -893,7 +891,6 @@ class LongformerClassificationHead(nn.Module):
)
class LongformerForQuestionAnswering(BertPreTrainedModel):
    config_class = LongformerConfig
    pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "longformer"

    def __init__(self, config):

@@ -1018,7 +1015,6 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
)
class LongformerForTokenClassification(BertPreTrainedModel):
    config_class = LongformerConfig
    pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "longformer"

    def __init__(self, config):

@@ -1119,7 +1115,6 @@ class LongformerForTokenClassification(BertPreTrainedModel):
)
class LongformerForMultipleChoice(BertPreTrainedModel):
    config_class = LongformerConfig
    pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "longformer"

    def __init__(self, config):

@@ -18,6 +18,11 @@
from transformers.modeling_bart import BartForConditionalGeneration


MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
    # See all Marian models at https://huggingface.co/models?search=Helsinki-NLP
]


class MarianMTModel(BartForConditionalGeneration):
    r"""
    Pytorch version of marian-nmt's transformer.h (c++). Designed for the OPUS-NMT translation checkpoints.

@@ -41,8 +46,6 @@ class MarianMTModel(BartForConditionalGeneration):

    """

    pretrained_model_archive_map = {}  # see https://huggingface.co/models?search=Helsinki-NLP

    def prepare_logits_for_generation(self, logits, cur_len, max_length):
        logits[:, self.config.pad_token_id] = float("-inf")
        if cur_len == max_length - 1 and self.config.eos_token_id is not None:
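The `prepare_logits_for_generation` override above leans on a standard masking idiom: writing `-inf` into a vocabulary column gives that token zero probability after softmax, so `<pad>` can never be sampled. A toy illustration -- shapes and the pad id are invented:

# Toy sketch of the -inf trick: after softmax the masked column is exactly 0.
import torch

logits = torch.randn(2, 8)          # (batch_size, vocab_size), toy values
pad_token_id = 3                    # hypothetical pad id
logits[:, pad_token_id] = float("-inf")
probs = logits.softmax(dim=-1)
assert torch.all(probs[:, pad_token_id] == 0)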
@@ -33,7 +33,10 @@ from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv

logger = logging.getLogger(__name__)

OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://cdn.huggingface.co/openai-gpt-pytorch_model.bin"}
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "openai-gpt",
    # See all OpenAI GPT models at https://huggingface.co/models?filter=openai-gpt
]


def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):

@@ -252,7 +255,6 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
    """

    config_class = OpenAIGPTConfig
    pretrained_model_archive_map = OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_openai_gpt
    base_model_prefix = "transformer"

@@ -35,10 +35,11 @@ from .modeling_utils import PreTrainedModel, apply_chunking_to_forward

logger = logging.getLogger(__name__)

REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "google/reformer-crime-and-punishment": "https://cdn.huggingface.co/google/reformer-crime-and-punishment/pytorch_model.bin",
    "google/reformer-enwik8": "https://cdn.huggingface.co/google/reformer-enwik8/pytorch_model.bin",
}
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/reformer-crime-and-punishment",
    "google/reformer-enwik8",
    # See all Reformer models at https://huggingface.co/models?filter=reformer
]


def mish(x):

@@ -1373,7 +1374,6 @@ class ReformerPreTrainedModel(PreTrainedModel):
    """

    config_class = ReformerConfig
    pretrained_model_archive_map = REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "reformer"

    @property

@@ -30,14 +30,15 @@ from .modeling_utils import create_position_ids_from_input_ids

logger = logging.getLogger(__name__)

ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "roberta-base": "https://cdn.huggingface.co/roberta-base-pytorch_model.bin",
    "roberta-large": "https://cdn.huggingface.co/roberta-large-pytorch_model.bin",
    "roberta-large-mnli": "https://cdn.huggingface.co/roberta-large-mnli-pytorch_model.bin",
    "distilroberta-base": "https://cdn.huggingface.co/distilroberta-base-pytorch_model.bin",
    "roberta-base-openai-detector": "https://cdn.huggingface.co/roberta-base-openai-detector-pytorch_model.bin",
    "roberta-large-openai-detector": "https://cdn.huggingface.co/roberta-large-openai-detector-pytorch_model.bin",
}
ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "roberta-base",
    "roberta-large",
    "roberta-large-mnli",
    "distilroberta-base",
    "roberta-base-openai-detector",
    "roberta-large-openai-detector",
    # See all RoBERTa models at https://huggingface.co/models?filter=roberta
]


class RobertaEmbeddings(BertEmbeddings):

@@ -142,7 +143,6 @@ class RobertaModel(BertModel):
    """

    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):

@@ -161,7 +161,6 @@ class RobertaModel(BertModel):
@add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
class RobertaForMaskedLM(BertPreTrainedModel):
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):

@@ -276,7 +275,6 @@ class RobertaLMHead(nn.Module):
)
class RobertaForSequenceClassification(BertPreTrainedModel):
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):

@@ -367,7 +365,6 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
)
class RobertaForMultipleChoice(BertPreTrainedModel):
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):

@@ -466,7 +463,6 @@ class RobertaForMultipleChoice(BertPreTrainedModel):
)
class RobertaForTokenClassification(BertPreTrainedModel):
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):

@@ -586,7 +582,6 @@ class RobertaClassificationHead(nn.Module):
)
class RobertaForQuestionAnswering(BertPreTrainedModel):
    config_class = RobertaConfig
    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

    def __init__(self, config):
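A note on `base_model_prefix = "roberta"`, repeated in every hunk above: it names the attribute under which head models store the base encoder, and `from_pretrained` uses it to add or strip that prefix when matching checkpoint weight names. A small sketch of the effect:

# Sketch: the head model keeps the encoder under `.roberta`, so a plain
# roberta-base checkpoint loads into it; the classification head stays
# freshly initialized and is expected to be fine-tuned.
from transformers import RobertaForSequenceClassification

model = RobertaForSequenceClassification.from_pretrained("roberta-base")
print(type(model.roberta).__name__)  # RobertaModel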
@@ -36,13 +36,14 @@ logger = logging.getLogger(__name__)
# This dict contains shortcut names and associated url
# for the pretrained weights provided with the models
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "t5-small": "https://cdn.huggingface.co/t5-small-pytorch_model.bin",
    "t5-base": "https://cdn.huggingface.co/t5-base-pytorch_model.bin",
    "t5-large": "https://cdn.huggingface.co/t5-large-pytorch_model.bin",
    "t5-3b": "https://cdn.huggingface.co/t5-3b-pytorch_model.bin",
    "t5-11b": "https://cdn.huggingface.co/t5-11b-pytorch_model.bin",
}
T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]


####################################################

@@ -555,7 +556,6 @@ class T5PreTrainedModel(PreTrainedModel):
    """

    config_class = T5Config
    pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
@@ -29,16 +29,17 @@ from .tokenization_utils import BatchEncoding

logger = logging.getLogger(__name__)


TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "albert-base-v1": "https://cdn.huggingface.co/albert-base-v1-with-prefix-tf_model.h5",
    "albert-large-v1": "https://cdn.huggingface.co/albert-large-v1-with-prefix-tf_model.h5",
    "albert-xlarge-v1": "https://cdn.huggingface.co/albert-xlarge-v1-with-prefix-tf_model.h5",
    "albert-xxlarge-v1": "https://cdn.huggingface.co/albert-xxlarge-v1-with-prefix-tf_model.h5",
    "albert-base-v2": "https://cdn.huggingface.co/albert-base-v2-with-prefix-tf_model.h5",
    "albert-large-v2": "https://cdn.huggingface.co/albert-large-v2-with-prefix-tf_model.h5",
    "albert-xlarge-v2": "https://cdn.huggingface.co/albert-xlarge-v2-with-prefix-tf_model.h5",
    "albert-xxlarge-v2": "https://cdn.huggingface.co/albert-xxlarge-v2-with-prefix-tf_model.h5",
}
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "albert-base-v1",
    "albert-large-v1",
    "albert-xlarge-v1",
    "albert-xxlarge-v1",
    "albert-base-v2",
    "albert-large-v2",
    "albert-xlarge-v2",
    "albert-xxlarge-v2",
    # See all ALBERT models at https://huggingface.co/models?filter=albert
]


class TFAlbertEmbeddings(tf.keras.layers.Layer):

@@ -440,7 +441,6 @@ class TFAlbertPreTrainedModel(TFPreTrainedModel):
    """

    config_class = AlbertConfig
    pretrained_model_archive_map = TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "albert"

@@ -34,7 +34,6 @@ from .configuration_auto import (
)
from .configuration_utils import PretrainedConfig
from .modeling_tf_albert import (
    TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    TFAlbertForMaskedLM,
    TFAlbertForMultipleChoice,
    TFAlbertForPreTraining,

@@ -43,7 +42,6 @@ from .modeling_tf_albert import (
    TFAlbertModel,
)
from .modeling_tf_bert import (
    TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    TFBertForMaskedLM,
    TFBertForMultipleChoice,
    TFBertForPreTraining,

@@ -52,40 +50,32 @@ from .modeling_tf_bert import (
    TFBertForTokenClassification,
    TFBertModel,
)
from .modeling_tf_ctrl import TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, TFCTRLLMHeadModel, TFCTRLModel
from .modeling_tf_ctrl import TFCTRLLMHeadModel, TFCTRLModel
from .modeling_tf_distilbert import (
    TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
    TFDistilBertForMaskedLM,
    TFDistilBertForQuestionAnswering,
    TFDistilBertForSequenceClassification,
    TFDistilBertForTokenClassification,
    TFDistilBertModel,
)
from .modeling_tf_gpt2 import TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, TFGPT2LMHeadModel, TFGPT2Model
from .modeling_tf_openai import TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from .modeling_tf_gpt2 import TFGPT2LMHeadModel, TFGPT2Model
from .modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from .modeling_tf_roberta import (
    TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
    TFRobertaForMaskedLM,
    TFRobertaForQuestionAnswering,
    TFRobertaForSequenceClassification,
    TFRobertaForTokenClassification,
    TFRobertaModel,
)
from .modeling_tf_t5 import TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP, TFT5ForConditionalGeneration, TFT5Model
from .modeling_tf_transfo_xl import (
    TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
    TFTransfoXLLMHeadModel,
    TFTransfoXLModel,
)
from .modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
from .modeling_tf_transfo_xl import TFTransfoXLLMHeadModel, TFTransfoXLModel
from .modeling_tf_xlm import (
    TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
    TFXLMForQuestionAnsweringSimple,
    TFXLMForSequenceClassification,
    TFXLMModel,
    TFXLMWithLMHeadModel,
)
from .modeling_tf_xlnet import (
    TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
    TFXLNetForQuestionAnsweringSimple,
    TFXLNetForSequenceClassification,
    TFXLNetForTokenClassification,

@@ -97,24 +87,6 @@ from .modeling_tf_xlnet import (
logger = logging.getLogger(__name__)


TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP = dict(
    (key, value)
    for pretrained_map in [
        TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP,
    ]
    for key, value, in pretrained_map.items()
)
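The removed `TF_ALL_PRETRAINED_MODEL_ARCHIVE_MAP` comprehension above is nothing more than a flatten-and-merge over the per-model maps, with later maps winning on duplicate keys. A toy equivalent with made-up dicts, for readers puzzling over the generator syntax:

# Toy equivalent of the removed comprehension: merge several dicts into one.
maps = [{"a": 1}, {"b": 2}, {"a": 3}]
merged = dict((key, value) for m in maps for key, value in m.items())
assert merged == {"a": 3, "b": 2}  # later maps override earlier keys

# The same thing written as the more idiomatic loop:
merged2 = {}
for m in maps:
    merged2.update(m)
assert merged2 == merged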
TF_MODEL_MAPPING = OrderedDict(
    [
        (T5Config, TFT5Model),

@@ -208,20 +180,17 @@ class TFAutoModel(object):

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The base model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: TFT5Model (T5 model)
            - contains `distilbert`: TFDistilBertModel (DistilBERT model)
            - contains `roberta`: TFRobertaModel (RoBERTa model)
            - contains `bert`: TFBertModel (Bert model)
            - contains `openai-gpt`: TFOpenAIGPTModel (OpenAI GPT model)
            - contains `gpt2`: TFGPT2Model (OpenAI GPT-2 model)
            - contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
            - contains `xlnet`: TFXLNetModel (XLNet model)
            - contains `xlm`: TFXLMModel (XLM model)
            - contains `ctrl`: TFCTRLModel (CTRL model)
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `t5`: TFT5Model (T5 model)
            - `distilbert`: TFDistilBertModel (DistilBERT model)
            - `roberta`: TFRobertaModel (RoBERTa model)
            - `bert`: TFBertModel (Bert model)
            - `openai-gpt`: TFOpenAIGPTModel (OpenAI GPT model)
            - `gpt2`: TFGPT2Model (OpenAI GPT-2 model)
            - `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
            - `xlnet`: TFXLNetModel (XLNet model)
            - `xlm`: TFXLMModel (XLM model)
            - `ctrl`: TFCTRLModel (CTRL model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """
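A quick usage sketch of the behaviour this docstring describes; with no config supplied, the class is chosen by the pattern in the model id:

# Sketch: the "distilbert" substring in the id routes this call to
# TFDistilBertModel without the caller ever naming the class.
from transformers import TFAutoModel

model = TFAutoModel.from_pretrained("distilbert-base-uncased")
print(type(model).__name__)  # TFDistilBertModel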
@@ -276,17 +245,18 @@ class TFAutoModel(object):
        r""" Instantiates one of the base model classes of the library
        from a pre-trained model configuration.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: TFT5Model (T5 model)
            - contains `distilbert`: TFDistilBertModel (DistilBERT model)
            - contains `roberta`: TFRobertaModel (RoBERTa model)
            - contains `bert`: TFBertModel (Bert model)
            - contains `openai-gpt`: TFOpenAIGPTModel (OpenAI GPT model)
            - contains `gpt2`: TFGPT2Model (OpenAI GPT-2 model)
            - contains `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
            - contains `xlnet`: TFXLNetModel (XLNet model)
            - contains `ctrl`: TFCTRLModel (CTRL model)
        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `t5`: TFT5Model (T5 model)
            - `distilbert`: TFDistilBertModel (DistilBERT model)
            - `roberta`: TFRobertaModel (RoBERTa model)
            - `bert`: TFBertModel (Bert model)
            - `openai-gpt`: TFOpenAIGPTModel (OpenAI GPT model)
            - `gpt2`: TFGPT2Model (OpenAI GPT-2 model)
            - `transfo-xl`: TFTransfoXLModel (Transformer-XL model)
            - `xlnet`: TFXLNetModel (XLNet model)
            - `ctrl`: TFCTRLModel (CTRL model)

        Params:
            pretrained_model_name_or_path: either:

@@ -424,21 +394,18 @@ class TFAutoModelForPreTraining(object):

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: :class:`~transformers.TFT5ModelWithLMHead` (T5 model)
            - contains `distilbert`: :class:`~transformers.TFDistilBertForMaskedLM` (DistilBERT model)
            - contains `albert`: :class:`~transformers.TFAlbertForPreTraining` (ALBERT model)
            - contains `roberta`: :class:`~transformers.TFRobertaForMaskedLM` (RoBERTa model)
            - contains `bert`: :class:`~transformers.TFBertForPreTraining` (Bert model)
            - contains `openai-gpt`: :class:`~transformers.TFOpenAIGPTLMHeadModel` (OpenAI GPT model)
            - contains `gpt2`: :class:`~transformers.TFGPT2LMHeadModel` (OpenAI GPT-2 model)
            - contains `transfo-xl`: :class:`~transformers.TFTransfoXLLMHeadModel` (Transformer-XL model)
            - contains `xlnet`: :class:`~transformers.TFXLNetLMHeadModel` (XLNet model)
            - contains `xlm`: :class:`~transformers.TFXLMWithLMHeadModel` (XLM model)
            - contains `ctrl`: :class:`~transformers.TFCTRLLMHeadModel` (Salesforce CTRL model)
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `t5`: :class:`~transformers.TFT5ModelWithLMHead` (T5 model)
            - `distilbert`: :class:`~transformers.TFDistilBertForMaskedLM` (DistilBERT model)
            - `albert`: :class:`~transformers.TFAlbertForPreTraining` (ALBERT model)
            - `roberta`: :class:`~transformers.TFRobertaForMaskedLM` (RoBERTa model)
            - `bert`: :class:`~transformers.TFBertForPreTraining` (Bert model)
            - `openai-gpt`: :class:`~transformers.TFOpenAIGPTLMHeadModel` (OpenAI GPT model)
            - `gpt2`: :class:`~transformers.TFGPT2LMHeadModel` (OpenAI GPT-2 model)
            - `transfo-xl`: :class:`~transformers.TFTransfoXLLMHeadModel` (Transformer-XL model)
            - `xlnet`: :class:`~transformers.TFXLNetLMHeadModel` (XLNet model)
            - `xlm`: :class:`~transformers.TFXLMWithLMHeadModel` (XLM model)
            - `ctrl`: :class:`~transformers.TFCTRLLMHeadModel` (Salesforce CTRL model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

@@ -525,20 +492,17 @@ class TFAutoModelWithLMHead(object):

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: TFT5ForConditionalGeneration (T5 model)
            - contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
            - contains `roberta`: TFRobertaForMaskedLM (RoBERTa model)
            - contains `bert`: TFBertForMaskedLM (Bert model)
            - contains `openai-gpt`: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
            - contains `gpt2`: TFGPT2LMHeadModel (OpenAI GPT-2 model)
            - contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
            - contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
            - contains `xlm`: TFXLMWithLMHeadModel (XLM model)
            - contains `ctrl`: TFCTRLLMHeadModel (CTRL model)
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `t5`: TFT5ForConditionalGeneration (T5 model)
            - `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
            - `roberta`: TFRobertaForMaskedLM (RoBERTa model)
            - `bert`: TFBertForMaskedLM (Bert model)
            - `openai-gpt`: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
            - `gpt2`: TFGPT2LMHeadModel (OpenAI GPT-2 model)
            - `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
            - `xlnet`: TFXLNetLMHeadModel (XLNet model)
            - `xlm`: TFXLMWithLMHeadModel (XLM model)
            - `ctrl`: TFCTRLLMHeadModel (CTRL model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """
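The LM-head variant works the same way; a sketch with a real id, where the `gpt2` pattern selects the class with the language-modeling head on top:

# Sketch: "gpt2" routes to TFGPT2LMHeadModel, the base transformer plus an
# output embedding head for next-token prediction.
from transformers import TFAutoModelWithLMHead

lm = TFAutoModelWithLMHead.from_pretrained("gpt2")
print(type(lm).__name__)  # TFGPT2LMHeadModel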
@@ -595,20 +559,17 @@ class TFAutoModelWithLMHead(object):

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: TFT5ForConditionalGeneration (T5 model)
            - contains `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
            - contains `roberta`: TFRobertaForMaskedLM (RoBERTa model)
            - contains `bert`: TFBertForMaskedLM (Bert model)
            - contains `openai-gpt`: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
            - contains `gpt2`: TFGPT2LMHeadModel (OpenAI GPT-2 model)
            - contains `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
            - contains `xlnet`: TFXLNetLMHeadModel (XLNet model)
            - contains `xlm`: TFXLMWithLMHeadModel (XLM model)
            - contains `ctrl`: TFCTRLLMHeadModel (CTRL model)
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `t5`: TFT5ForConditionalGeneration (T5 model)
            - `distilbert`: TFDistilBertForMaskedLM (DistilBERT model)
            - `roberta`: TFRobertaForMaskedLM (RoBERTa model)
            - `bert`: TFBertForMaskedLM (Bert model)
            - `openai-gpt`: TFOpenAIGPTLMHeadModel (OpenAI GPT model)
            - `gpt2`: TFGPT2LMHeadModel (OpenAI GPT-2 model)
            - `transfo-xl`: TFTransfoXLLMHeadModel (Transformer-XL model)
            - `xlnet`: TFXLNetLMHeadModel (XLNet model)
            - `xlm`: TFXLMWithLMHeadModel (XLM model)
            - `ctrl`: TFCTRLLMHeadModel (CTRL model)

        Params:
            pretrained_model_name_or_path: either:

@@ -694,12 +655,9 @@ class TFAutoModelForMultipleChoice:

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `albert`: TFAlbertForMultipleChoice (Albert model)
            - contains `bert`: TFBertForMultipleChoice (Bert model)
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `albert`: TFAlbertForMultipleChoice (Albert model)
            - `bert`: TFBertForMultipleChoice (Bert model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """

@@ -751,12 +709,9 @@ class TFAutoModelForMultipleChoice:

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `albert`: TFAlbertForMultipleChoice (Albert model)
            - contains `bert`: TFBertForMultipleChoice (Bert model)
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `albert`: TFAlbertForMultipleChoice (Albert model)
            - `bert`: TFBertForMultipleChoice (Bert model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

@@ -847,15 +802,12 @@ class TFAutoModelForSequenceClassification(object):

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertForSequenceClassification (DistilBERT model)
            - contains `roberta`: TFRobertaForSequenceClassification (RoBERTa model)
            - contains `bert`: TFBertForSequenceClassification (Bert model)
            - contains `xlnet`: TFXLNetForSequenceClassification (XLNet model)
            - contains `xlm`: TFXLMForSequenceClassification (XLM model)
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `distilbert`: TFDistilBertForSequenceClassification (DistilBERT model)
            - `roberta`: TFRobertaForSequenceClassification (RoBERTa model)
            - `bert`: TFBertForSequenceClassification (Bert model)
            - `xlnet`: TFXLNetForSequenceClassification (XLNet model)
            - `xlm`: TFXLMForSequenceClassification (XLM model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """

@@ -910,15 +862,12 @@ class TFAutoModelForSequenceClassification(object):

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertForSequenceClassification (DistilBERT model)
            - contains `roberta`: TFRobertaForSequenceClassification (RoBERTa model)
            - contains `bert`: TFBertForSequenceClassification (Bert model)
            - contains `xlnet`: TFXLNetForSequenceClassification (XLNet model)
            - contains `xlm`: TFXLMForSequenceClassification (XLM model)
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `distilbert`: TFDistilBertForSequenceClassification (DistilBERT model)
            - `roberta`: TFRobertaForSequenceClassification (RoBERTa model)
            - `bert`: TFBertForSequenceClassification (Bert model)
            - `xlnet`: TFXLNetForSequenceClassification (XLNet model)
            - `xlm`: TFXLMForSequenceClassification (XLM model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

@@ -1009,16 +958,13 @@ class TFAutoModelForQuestionAnswering(object):

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertForQuestionAnswering (DistilBERT model)
            - contains `albert`: TFAlbertForQuestionAnswering (ALBERT model)
            - contains `roberta`: TFRobertaForQuestionAnswering (RoBERTa model)
            - contains `bert`: TFBertForQuestionAnswering (Bert model)
            - contains `xlnet`: TFXLNetForQuestionAnswering (XLNet model)
            - contains `xlm`: TFXLMForQuestionAnswering (XLM model)
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `distilbert`: TFDistilBertForQuestionAnswering (DistilBERT model)
            - `albert`: TFAlbertForQuestionAnswering (ALBERT model)
            - `roberta`: TFRobertaForQuestionAnswering (RoBERTa model)
            - `bert`: TFBertForQuestionAnswering (Bert model)
            - `xlnet`: TFXLNetForQuestionAnswering (XLNet model)
            - `xlm`: TFXLMForQuestionAnswering (XLM model)

        This class cannot be instantiated using `__init__()` (throws an error).
    """

@@ -1074,16 +1020,13 @@ class TFAutoModelForQuestionAnswering(object):

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `distilbert`: TFDistilBertForQuestionAnswering (DistilBERT model)
            - contains `albert`: TFAlbertForQuestionAnswering (ALBERT model)
            - contains `roberta`: TFRobertaForQuestionAnswering (RoBERTa model)
            - contains `bert`: TFBertForQuestionAnswering (Bert model)
            - contains `xlnet`: TFXLNetForQuestionAnswering (XLNet model)
            - contains `xlm`: TFXLMForQuestionAnswering (XLM model)
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `distilbert`: TFDistilBertForQuestionAnswering (DistilBERT model)
            - `albert`: TFAlbertForQuestionAnswering (ALBERT model)
            - `roberta`: TFRobertaForQuestionAnswering (RoBERTa model)
            - `bert`: TFBertForQuestionAnswering (Bert model)
            - `xlnet`: TFXLNetForQuestionAnswering (XLNet model)
            - `xlm`: TFXLMForQuestionAnswering (XLM model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`

@@ -1215,14 +1158,11 @@ class TFAutoModelForTokenClassification:

        The `from_pretrained()` method takes care of returning the correct model class instance
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string.

        The model class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `bert`: BertForTokenClassification (Bert model)
            - contains `xlnet`: XLNetForTokenClassification (XLNet model)
            - contains `distilbert`: DistilBertForTokenClassification (DistilBert model)
            - contains `roberta`: RobertaForTokenClassification (Roberta model)
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `bert`: BertForTokenClassification (Bert model)
            - `xlnet`: XLNetForTokenClassification (XLNet model)
            - `distilbert`: DistilBertForTokenClassification (DistilBert model)
            - `roberta`: RobertaForTokenClassification (Roberta model)

        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
        To train the model, you should first set it back in training mode with `model.train()`
@@ -30,28 +30,29 @@ from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)


TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "bert-base-uncased": "https://cdn.huggingface.co/bert-base-uncased-tf_model.h5",
    "bert-large-uncased": "https://cdn.huggingface.co/bert-large-uncased-tf_model.h5",
    "bert-base-cased": "https://cdn.huggingface.co/bert-base-cased-tf_model.h5",
    "bert-large-cased": "https://cdn.huggingface.co/bert-large-cased-tf_model.h5",
    "bert-base-multilingual-uncased": "https://cdn.huggingface.co/bert-base-multilingual-uncased-tf_model.h5",
    "bert-base-multilingual-cased": "https://cdn.huggingface.co/bert-base-multilingual-cased-tf_model.h5",
    "bert-base-chinese": "https://cdn.huggingface.co/bert-base-chinese-tf_model.h5",
    "bert-base-german-cased": "https://cdn.huggingface.co/bert-base-german-cased-tf_model.h5",
    "bert-large-uncased-whole-word-masking": "https://cdn.huggingface.co/bert-large-uncased-whole-word-masking-tf_model.h5",
    "bert-large-cased-whole-word-masking": "https://cdn.huggingface.co/bert-large-cased-whole-word-masking-tf_model.h5",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "https://cdn.huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad-tf_model.h5",
    "bert-large-cased-whole-word-masking-finetuned-squad": "https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-tf_model.h5",
    "bert-base-cased-finetuned-mrpc": "https://cdn.huggingface.co/bert-base-cased-finetuned-mrpc-tf_model.h5",
    "bert-base-japanese": "https://cdn.huggingface.co/cl-tohoku/bert-base-japanese/tf_model.h5",
    "bert-base-japanese-whole-word-masking": "https://cdn.huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/tf_model.h5",
    "bert-base-japanese-char": "https://cdn.huggingface.co/cl-tohoku/bert-base-japanese-char/tf_model.h5",
    "bert-base-japanese-char-whole-word-masking": "https://cdn.huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/tf_model.h5",
    "bert-base-finnish-cased-v1": "https://cdn.huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/tf_model.h5",
    "bert-base-finnish-uncased-v1": "https://cdn.huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/tf_model.h5",
    "bert-base-dutch-cased": "https://cdn.huggingface.co/wietsedv/bert-base-dutch-cased/tf_model.h5",
}
TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "bert-base-uncased",
    "bert-large-uncased",
    "bert-base-cased",
    "bert-large-cased",
    "bert-base-multilingual-uncased",
    "bert-base-multilingual-cased",
    "bert-base-chinese",
    "bert-base-german-cased",
    "bert-large-uncased-whole-word-masking",
    "bert-large-cased-whole-word-masking",
    "bert-large-uncased-whole-word-masking-finetuned-squad",
    "bert-large-cased-whole-word-masking-finetuned-squad",
    "bert-base-cased-finetuned-mrpc",
    "cl-tohoku/bert-base-japanese",
    "cl-tohoku/bert-base-japanese-whole-word-masking",
    "cl-tohoku/bert-base-japanese-char",
    "cl-tohoku/bert-base-japanese-char-whole-word-masking",
    "TurkuNLP/bert-base-finnish-cased-v1",
    "TurkuNLP/bert-base-finnish-uncased-v1",
    "wietsedv/bert-base-dutch-cased",
    # See all BERT models at https://huggingface.co/models?filter=bert
]


def gelu(x):

@@ -585,7 +586,6 @@ class TFBertPreTrainedModel(TFPreTrainedModel):
    """

    config_class = BertConfig
    pretrained_model_archive_map = TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "bert"
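The story is identical on the TF2 side: an id now maps to a `tf_model.h5` on the hub, resolved by name rather than through this table. A sketch, including the conversion path for ids that only publish PyTorch weights:

# Sketch: TF weights are fetched by id; from_pt=True converts a PyTorch
# checkpoint on the fly when no tf_model.h5 exists for that id.
from transformers import TFBertModel

model = TFBertModel.from_pretrained("bert-base-cased")
# model = TFBertModel.from_pretrained("some/pytorch-only-id", from_pt=True)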
@@ -30,7 +30,9 @@ from .modeling_tf_roberta import (

logger = logging.getLogger(__name__)

TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {}
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    # See all CamemBERT models at https://huggingface.co/models?filter=camembert
]


CAMEMBERT_START_DOCSTRING = r"""

@@ -72,7 +74,6 @@ class TFCamembertModel(TFRobertaModel):
    """

    config_class = CamembertConfig
    pretrained_model_archive_map = TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP


@add_start_docstrings(

@@ -85,7 +86,6 @@ class TFCamembertForMaskedLM(TFRobertaForMaskedLM):
    """

    config_class = CamembertConfig
    pretrained_model_archive_map = TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP


@add_start_docstrings(

@@ -100,7 +100,6 @@ class TFCamembertForSequenceClassification(TFRobertaForSequenceClassification):
    """

    config_class = CamembertConfig
    pretrained_model_archive_map = TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP


@add_start_docstrings(

@@ -115,4 +114,3 @@ class TFCamembertForTokenClassification(TFRobertaForTokenClassification):
    """

    config_class = CamembertConfig
    pretrained_model_archive_map = TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP

@@ -29,7 +29,10 @@ from .tokenization_utils import BatchEncoding

logger = logging.getLogger(__name__)

TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://cdn.huggingface.co/ctrl-tf_model.h5"}
TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "ctrl"
    # See all CTRL models at https://huggingface.co/models?filter=ctrl
]


def angle_defn(pos, i, d_model_size):

@@ -379,7 +382,6 @@ class TFCTRLPreTrainedModel(TFPreTrainedModel):
    """

    config_class = CTRLConfig
    pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"

@@ -31,14 +31,15 @@ from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)


TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "distilbert-base-uncased": "https://cdn.huggingface.co/distilbert-base-uncased-tf_model.h5",
    "distilbert-base-uncased-distilled-squad": "https://cdn.huggingface.co/distilbert-base-uncased-distilled-squad-tf_model.h5",
    "distilbert-base-cased": "https://cdn.huggingface.co/distilbert-base-cased-tf_model.h5",
    "distilbert-base-cased-distilled-squad": "https://cdn.huggingface.co/distilbert-base-cased-distilled-squad-tf_model.h5",
    "distilbert-base-multilingual-cased": "https://cdn.huggingface.co/distilbert-base-multilingual-cased-tf_model.h5",
    "distilbert-base-uncased-finetuned-sst-2-english": "https://cdn.huggingface.co/distilbert-base-uncased-finetuned-sst-2-english-tf_model.h5",
}
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "distilbert-base-uncased",
    "distilbert-base-uncased-distilled-squad",
    "distilbert-base-cased",
    "distilbert-base-cased-distilled-squad",
    "distilbert-base-multilingual-cased",
    "distilbert-base-uncased-finetuned-sst-2-english",
    # See all DistilBERT models at https://huggingface.co/models?filter=distilbert
]


# UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #

@@ -467,7 +468,6 @@ class TFDistilBertPreTrainedModel(TFPreTrainedModel):
    """

    config_class = DistilBertConfig
    pretrained_model_archive_map = TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "distilbert"

@@ -13,14 +13,15 @@ from .tokenization_utils import BatchEncoding
logger = logging.getLogger(__name__)


TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "google/electra-small-generator": "https://cdn.huggingface.co/google/electra-small-generator/tf_model.h5",
    "google/electra-base-generator": "https://cdn.huggingface.co/google/electra-base-generator/tf_model.h5",
    "google/electra-large-generator": "https://cdn.huggingface.co/google/electra-large-generator/tf_model.h5",
    "google/electra-small-discriminator": "https://cdn.huggingface.co/google/electra-small-discriminator/tf_model.h5",
    "google/electra-base-discriminator": "https://cdn.huggingface.co/google/electra-base-discriminator/tf_model.h5",
    "google/electra-large-discriminator": "https://cdn.huggingface.co/google/electra-large-discriminator/tf_model.h5",
}
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/electra-small-generator",
    "google/electra-base-generator",
    "google/electra-large-generator",
    "google/electra-small-discriminator",
    "google/electra-base-discriminator",
    "google/electra-large-discriminator",
    # See all ELECTRA models at https://huggingface.co/models?filter=electra
]


class TFElectraEmbeddings(tf.keras.layers.Layer):

@@ -160,7 +161,6 @@ class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
class TFElectraPreTrainedModel(TFBertPreTrainedModel):

    config_class = ElectraConfig
    pretrained_model_archive_map = TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "electra"

    def get_extended_attention_mask(self, attention_mask, input_shape):

@@ -35,7 +35,9 @@ from .tokenization_utils import BatchEncoding

logger = logging.getLogger(__name__)

TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {}
TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    # See all Flaubert models at https://huggingface.co/models?filter=flaubert
]

FLAUBERT_START_DOCSTRING = r"""

@@ -104,7 +106,6 @@ FLAUBERT_INPUTS_DOCSTRING = r"""
)
class TFFlaubertModel(TFXLMModel):
    config_class = FlaubertConfig
    pretrained_model_archive_map = TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

@@ -309,7 +310,6 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
)
class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
    config_class = FlaubertConfig
    pretrained_model_archive_map = TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

@@ -323,7 +323,6 @@ class TFFlaubertWithLMHeadModel(TFXLMWithLMHeadModel):
)
class TFFlaubertForSequenceClassification(TFXLMForSequenceClassification):
    config_class = FlaubertConfig
    pretrained_model_archive_map = TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

@@ -37,13 +37,14 @@ from .tokenization_utils import BatchEncoding

logger = logging.getLogger(__name__)

TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "gpt2": "https://cdn.huggingface.co/gpt2-tf_model.h5",
    "gpt2-medium": "https://cdn.huggingface.co/gpt2-medium-tf_model.h5",
    "gpt2-large": "https://cdn.huggingface.co/gpt2-large-tf_model.h5",
    "gpt2-xl": "https://cdn.huggingface.co/gpt2-xl-tf_model.h5",
    "distilgpt2": "https://cdn.huggingface.co/distilgpt2-tf_model.h5",
}
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
    "distilgpt2",
    # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
]


def gelu(x):

@@ -389,7 +390,6 @@ class TFGPT2PreTrainedModel(TFPreTrainedModel):
    """

    config_class = GPT2Config
    pretrained_model_archive_map = TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"

@@ -36,7 +36,10 @@ from .tokenization_utils import BatchEncoding

logger = logging.getLogger(__name__)

TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP = {"openai-gpt": "https://cdn.huggingface.co/openai-gpt-tf_model.h5"}
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "openai-gpt",
    # See all OpenAI GPT models at https://huggingface.co/models?filter=openai-gpt
]


def gelu(x):

@@ -349,7 +352,6 @@ class TFOpenAIGPTPreTrainedModel(TFPreTrainedModel):
    """

    config_class = OpenAIGPTConfig
    pretrained_model_archive_map = TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"

@@ -28,12 +28,13 @@ from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list

logger = logging.getLogger(__name__)

TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "roberta-base": "https://cdn.huggingface.co/roberta-base-tf_model.h5",
    "roberta-large": "https://cdn.huggingface.co/roberta-large-tf_model.h5",
    "roberta-large-mnli": "https://cdn.huggingface.co/roberta-large-mnli-tf_model.h5",
    "distilroberta-base": "https://cdn.huggingface.co/distilroberta-base-tf_model.h5",
}
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "roberta-base",
    "roberta-large",
    "roberta-large-mnli",
    "distilroberta-base",
    # See all RoBERTa models at https://huggingface.co/models?filter=roberta
]


class TFRobertaEmbeddings(TFBertEmbeddings):

@@ -100,7 +101,6 @@ class TFRobertaPreTrainedModel(TFPreTrainedModel):
    """

    config_class = RobertaConfig
    pretrained_model_archive_map = TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "roberta"

@@ -30,13 +30,14 @@ from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list

logger = logging.getLogger(__name__)

TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "t5-small": "https://cdn.huggingface.co/t5-small-tf_model.h5",
    "t5-base": "https://cdn.huggingface.co/t5-base-tf_model.h5",
    "t5-large": "https://cdn.huggingface.co/t5-large-tf_model.h5",
    "t5-3b": "https://cdn.huggingface.co/t5-3b-tf_model.h5",
    "t5-11b": "https://cdn.huggingface.co/t5-11b-tf_model.h5",
}
TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "t5-small",
    "t5-base",
    "t5-large",
    "t5-3b",
    "t5-11b",
    # See all T5 models at https://huggingface.co/models?filter=t5
]

####################################################
# TF 2.0 Models are constructed using Keras imperative API by sub-classing

@@ -720,7 +721,6 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
    """

    config_class = T5Config
    pretrained_model_archive_map = TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"

    @property

@@ -30,9 +30,10 @@ from .tokenization_utils import BatchEncoding

logger = logging.getLogger(__name__)

TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "transfo-xl-wt103": "https://cdn.huggingface.co/transfo-xl-wt103-tf_model.h5",
}
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "transfo-xl-wt103",
    # See all Transformer XL models at https://huggingface.co/models?filter=transfo-xl
]


class TFPositionalEmbedding(tf.keras.layers.Layer):

@@ -630,7 +631,6 @@ class TFTransfoXLPreTrainedModel(TFPreTrainedModel):
    """

    config_class = TransfoXLConfig
    pretrained_model_archive_map = TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"
@ -112,7 +112,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
|
||||
Class attributes (overridden by derived classes):
|
||||
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
|
||||
- ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
|
||||
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
|
||||
|
||||
- ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
|
||||
|
@ -122,7 +121,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
|
||||
"""
|
||||
config_class = None
|
||||
pretrained_model_archive_map = {}
|
||||
base_model_prefix = ""
|
||||
|
||||
@property
|
||||
|
@ -338,9 +336,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
|
||||
# Load model
|
||||
if pretrained_model_name_or_path is not None:
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
|
||||
elif os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
|
||||
# Load from a TF 2.0 checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
|
||||
|
@ -364,8 +360,8 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
use_cdn=use_cdn,
|
||||
)
|
||||
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
|
@ -373,20 +369,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
|
|||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
)
|
||||
except EnvironmentError as e:
|
||||
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
|
||||
logger.error("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
|
||||
else:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
"associated to this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
", ".join(cls.pretrained_model_archive_map.keys()),
|
||||
archive_file,
|
||||
)
|
||||
)
|
||||
raise e
|
||||
if resolved_archive_file is None:
|
||||
raise EnvironmentError
|
||||
except EnvironmentError:
|
||||
msg = (
|
||||
f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
|
||||
f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
|
||||
f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
|
||||
)
|
||||
raise EnvironmentError(msg)
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info("loading weights file {}".format(archive_file))
|
||||
else:
|
||||
|
|
|
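With the map gone, `from_pretrained` no longer special-cases known shortcut names: anything that is not a local directory is treated as a model identifier on huggingface.co, and all download failures collapse into one uniform `EnvironmentError` whose message points at the model hub. A hedged sketch of the resulting behavior (the model id below is deliberately fake):

# Sketch of the unified failure mode after this change.
from transformers import TFBertModel

try:
    # "no-such-model-xyz" is a hypothetical, nonexistent identifier.
    TFBertModel.from_pretrained("no-such-model-xyz")
except EnvironmentError as err:
    # The message now suggests checking https://huggingface.co/models
    # instead of listing the keys of a hardcoded archive map.
    print(err)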
@@ -31,18 +31,19 @@ from .tokenization_utils import BatchEncoding

logger = logging.getLogger(__name__)

TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "xlm-mlm-en-2048": "https://cdn.huggingface.co/xlm-mlm-en-2048-tf_model.h5",
    "xlm-mlm-ende-1024": "https://cdn.huggingface.co/xlm-mlm-ende-1024-tf_model.h5",
    "xlm-mlm-enfr-1024": "https://cdn.huggingface.co/xlm-mlm-enfr-1024-tf_model.h5",
    "xlm-mlm-enro-1024": "https://cdn.huggingface.co/xlm-mlm-enro-1024-tf_model.h5",
    "xlm-mlm-tlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-tlm-xnli15-1024-tf_model.h5",
    "xlm-mlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-xnli15-1024-tf_model.h5",
    "xlm-clm-enfr-1024": "https://cdn.huggingface.co/xlm-clm-enfr-1024-tf_model.h5",
    "xlm-clm-ende-1024": "https://cdn.huggingface.co/xlm-clm-ende-1024-tf_model.h5",
    "xlm-mlm-17-1280": "https://cdn.huggingface.co/xlm-mlm-17-1280-tf_model.h5",
    "xlm-mlm-100-1280": "https://cdn.huggingface.co/xlm-mlm-100-1280-tf_model.h5",
}
TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "xlm-mlm-en-2048",
    "xlm-mlm-ende-1024",
    "xlm-mlm-enfr-1024",
    "xlm-mlm-enro-1024",
    "xlm-mlm-tlm-xnli15-1024",
    "xlm-mlm-xnli15-1024",
    "xlm-clm-enfr-1024",
    "xlm-clm-ende-1024",
    "xlm-mlm-17-1280",
    "xlm-mlm-100-1280",
    # See all XLM models at https://huggingface.co/models?filter=xlm
]


def create_sinusoidal_embeddings(n_pos, dim, out):

@@ -470,7 +471,6 @@ class TFXLMPreTrainedModel(TFPreTrainedModel):
    """

    config_class = XLMConfig
    pretrained_model_archive_map = TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"

    @property

@@ -30,7 +30,9 @@ from .modeling_tf_roberta import (

logger = logging.getLogger(__name__)

TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {}
TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    # See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta
]


XLM_ROBERTA_START_DOCSTRING = r"""

@@ -72,7 +74,6 @@ class TFXLMRobertaModel(TFRobertaModel):
    """

    config_class = XLMRobertaConfig
    pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP


@add_start_docstrings(

@@ -85,7 +86,6 @@ class TFXLMRobertaForMaskedLM(TFRobertaForMaskedLM):
    """

    config_class = XLMRobertaConfig
    pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP


@add_start_docstrings(

@@ -100,7 +100,6 @@ class TFXLMRobertaForSequenceClassification(TFRobertaForSequenceClassification):
    """

    config_class = XLMRobertaConfig
    pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP


@add_start_docstrings(

@@ -115,4 +114,3 @@ class TFXLMRobertaForTokenClassification(TFRobertaForTokenClassification):
    """

    config_class = XLMRobertaConfig
    pretrained_model_archive_map = TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP

@@ -37,10 +37,11 @@ from .tokenization_utils import BatchEncoding

logger = logging.getLogger(__name__)

TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "xlnet-base-cased": "https://cdn.huggingface.co/xlnet-base-cased-tf_model.h5",
    "xlnet-large-cased": "https://cdn.huggingface.co/xlnet-large-cased-tf_model.h5",
}
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "xlnet-base-cased",
    "xlnet-large-cased",
    # See all XLNet models at https://huggingface.co/models?filter=xlnet
]


def gelu(x):

@@ -701,7 +702,6 @@ class TFXLNetPreTrainedModel(TFPreTrainedModel):
    """

    config_class = XLNetConfig
    pretrained_model_archive_map = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"

@@ -33,9 +33,10 @@ from .modeling_utils import PreTrainedModel

logger = logging.getLogger(__name__)

TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "transfo-xl-wt103": "https://cdn.huggingface.co/transfo-xl-wt103-pytorch_model.bin",
}
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "transfo-xl-wt103",
    # See all Transformer XL models at https://huggingface.co/models?filter=transfo-xl
]


def build_tf_to_pytorch_map(model, config):

@@ -453,7 +454,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
    """

    config_class = TransfoXLConfig
    pretrained_model_archive_map = TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_transfo_xl
    base_model_prefix = "transformer"

@@ -110,6 +110,9 @@ class ModuleUtilsMixin:

    @property
    def device(self) -> device:
        """
        Get torch.device from module, assuming that the whole module has one device.
        """
        try:
            return next(self.parameters()).device
        except StopIteration:

@@ -125,6 +128,9 @@ class ModuleUtilsMixin:

    @property
    def dtype(self) -> dtype:
        """
        Get torch.dtype from module, assuming that the whole module has one dtype.
        """
        try:
            return next(self.parameters()).dtype
        except StopIteration:

@@ -249,7 +255,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` of with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,

@@ -259,7 +264,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    pretrained_model_archive_map = {}
    base_model_prefix = ""

    @property

@@ -587,9 +591,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
            if os.path.isdir(pretrained_model_name_or_path):
                if from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")):
                    # Load from a TF 1.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")

@@ -622,8 +624,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                    use_cdn=use_cdn,
                )

            # redirect to the cache, if necessary
            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,

@@ -632,20 +634,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError
            except EnvironmentError:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    msg = "Couldn't reach server at '{}' to download pretrained weights.".format(archive_file)
                else:
                    msg = (
                        "Model name '{}' was not found in model name list ({}). "
                        "We assumed '{}' was a path or url to model weight files named one of {} but "
                        "couldn't find any such file at this path or url.".format(
                            pretrained_model_name_or_path,
                            ", ".join(cls.pretrained_model_archive_map.keys()),
                            archive_file,
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME, TF_WEIGHTS_NAME],
                        )
                    )
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {WEIGHTS_NAME}, {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)

            if resolved_archive_file == archive_file:
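The `ModuleUtilsMixin` hunks above also surface the `device` and `dtype` properties, which read the device and dtype of the first parameter. A minimal usage sketch (model id illustrative; any model with parameters behaves the same):

# Sketch: the device/dtype properties shown in the diff above.
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
print(model.device)  # torch.device of the parameters, e.g. cpu
print(model.dtype)   # torch.dtype of the parameters, e.g. torch.float32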
@@ -34,18 +34,19 @@ from .modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead, prune_l

logger = logging.getLogger(__name__)

XLM_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "xlm-mlm-en-2048": "https://cdn.huggingface.co/xlm-mlm-en-2048-pytorch_model.bin",
    "xlm-mlm-ende-1024": "https://cdn.huggingface.co/xlm-mlm-ende-1024-pytorch_model.bin",
    "xlm-mlm-enfr-1024": "https://cdn.huggingface.co/xlm-mlm-enfr-1024-pytorch_model.bin",
    "xlm-mlm-enro-1024": "https://cdn.huggingface.co/xlm-mlm-enro-1024-pytorch_model.bin",
    "xlm-mlm-tlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-tlm-xnli15-1024-pytorch_model.bin",
    "xlm-mlm-xnli15-1024": "https://cdn.huggingface.co/xlm-mlm-xnli15-1024-pytorch_model.bin",
    "xlm-clm-enfr-1024": "https://cdn.huggingface.co/xlm-clm-enfr-1024-pytorch_model.bin",
    "xlm-clm-ende-1024": "https://cdn.huggingface.co/xlm-clm-ende-1024-pytorch_model.bin",
    "xlm-mlm-17-1280": "https://cdn.huggingface.co/xlm-mlm-17-1280-pytorch_model.bin",
    "xlm-mlm-100-1280": "https://cdn.huggingface.co/xlm-mlm-100-1280-pytorch_model.bin",
}
XLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "xlm-mlm-en-2048",
    "xlm-mlm-ende-1024",
    "xlm-mlm-enfr-1024",
    "xlm-mlm-enro-1024",
    "xlm-mlm-tlm-xnli15-1024",
    "xlm-mlm-xnli15-1024",
    "xlm-clm-enfr-1024",
    "xlm-clm-ende-1024",
    "xlm-mlm-17-1280",
    "xlm-mlm-100-1280",
    # See all XLM models at https://huggingface.co/models?filter=xlm
]


def create_sinusoidal_embeddings(n_pos, dim, out):

@@ -207,7 +208,6 @@ class XLMPreTrainedModel(PreTrainedModel):
    """

    config_class = XLMConfig
    pretrained_model_archive_map = XLM_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = None
    base_model_prefix = "transformer"

@@ -31,14 +31,15 @@ from .modeling_roberta import (

logger = logging.getLogger(__name__)

XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "xlm-roberta-base": "https://cdn.huggingface.co/xlm-roberta-base-pytorch_model.bin",
    "xlm-roberta-large": "https://cdn.huggingface.co/xlm-roberta-large-pytorch_model.bin",
    "xlm-roberta-large-finetuned-conll02-dutch": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-dutch-pytorch_model.bin",
    "xlm-roberta-large-finetuned-conll02-spanish": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll02-spanish-pytorch_model.bin",
    "xlm-roberta-large-finetuned-conll03-english": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-english-pytorch_model.bin",
    "xlm-roberta-large-finetuned-conll03-german": "https://cdn.huggingface.co/xlm-roberta-large-finetuned-conll03-german-pytorch_model.bin",
}
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "xlm-roberta-base",
    "xlm-roberta-large",
    "xlm-roberta-large-finetuned-conll02-dutch",
    "xlm-roberta-large-finetuned-conll02-spanish",
    "xlm-roberta-large-finetuned-conll03-english",
    "xlm-roberta-large-finetuned-conll03-german",
    # See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta
]


XLM_ROBERTA_START_DOCSTRING = r"""

@@ -65,7 +66,6 @@ class XLMRobertaModel(RobertaModel):
    """

    config_class = XLMRobertaConfig
    pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP


@add_start_docstrings(

@@ -78,7 +78,6 @@ class XLMRobertaForMaskedLM(RobertaForMaskedLM):
    """

    config_class = XLMRobertaConfig
    pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP


@add_start_docstrings(

@@ -93,7 +92,6 @@ class XLMRobertaForSequenceClassification(RobertaForSequenceClassification):
    """

    config_class = XLMRobertaConfig
    pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP


@add_start_docstrings(

@@ -108,7 +106,6 @@ class XLMRobertaForMultipleChoice(RobertaForMultipleChoice):
    """

    config_class = XLMRobertaConfig
    pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP


@add_start_docstrings(

@@ -123,4 +120,3 @@ class XLMRobertaForTokenClassification(RobertaForTokenClassification):
    """

    config_class = XLMRobertaConfig
    pretrained_model_archive_map = XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP

@@ -32,10 +32,11 @@ from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogit

logger = logging.getLogger(__name__)

XLNET_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "xlnet-base-cased": "https://cdn.huggingface.co/xlnet-base-cased-pytorch_model.bin",
    "xlnet-large-cased": "https://cdn.huggingface.co/xlnet-large-cased-pytorch_model.bin",
}
XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "xlnet-base-cased",
    "xlnet-large-cased",
    # See all XLNet models at https://huggingface.co/models?filter=xlnet
]


def build_tf_xlnet_to_pytorch_map(model, config, tf_weights=None):

@@ -459,7 +460,6 @@ class XLNetPreTrainedModel(PreTrainedModel):
    """

    config_class = XLNetConfig
    pretrained_model_archive_map = XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_xlnet
    base_model_prefix = "transformer"

@@ -97,27 +97,24 @@ class AutoTokenizer:
    when created with the `AutoTokenizer.from_pretrained(pretrained_model_name_or_path)`
    class method.

    The `from_pretrained()` method take care of returning the correct tokenizer class instance
    The `from_pretrained()` method takes care of returning the correct tokenizer class instance
    based on the `model_type` property of the config object, or when it's missing,
    falling back to using pattern matching on the `pretrained_model_name_or_path` string.

    The tokenizer class to instantiate is selected as the first pattern matching
    in the `pretrained_model_name_or_path` string (in the following order):
        - contains `t5`: T5Tokenizer (T5 model)
        - contains `distilbert`: DistilBertTokenizer (DistilBert model)
        - contains `albert`: AlbertTokenizer (ALBERT model)
        - contains `camembert`: CamembertTokenizer (CamemBERT model)
        - contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
        - contains `longformer`: LongformerTokenizer (AllenAI Longformer model)
        - contains `roberta`: RobertaTokenizer (RoBERTa model)
        - contains `bert`: BertTokenizer (Bert model)
        - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
        - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
        - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
        - contains `xlnet`: XLNetTokenizer (XLNet model)
        - contains `xlm`: XLMTokenizer (XLM model)
        - contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
        - contains `electra`: ElectraTokenizer (Google ELECTRA model)
    falling back to using pattern matching on the `pretrained_model_name_or_path` string:
        - `t5`: T5Tokenizer (T5 model)
        - `distilbert`: DistilBertTokenizer (DistilBert model)
        - `albert`: AlbertTokenizer (ALBERT model)
        - `camembert`: CamembertTokenizer (CamemBERT model)
        - `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
        - `longformer`: LongformerTokenizer (AllenAI Longformer model)
        - `roberta`: RobertaTokenizer (RoBERTa model)
        - `bert`: BertTokenizer (Bert model)
        - `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
        - `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
        - `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
        - `xlnet`: XLNetTokenizer (XLNet model)
        - `xlm`: XLMTokenizer (XLM model)
        - `ctrl`: CTRLTokenizer (Salesforce CTRL model)
        - `electra`: ElectraTokenizer (Google ELECTRA model)

    This class cannot be instantiated using `__init__()` (throw an error).
    """

@@ -133,24 +130,25 @@ class AutoTokenizer:
        r""" Instantiate one of the tokenizer classes of the library
        from a pre-trained model vocabulary.

        The tokenizer class to instantiate is selected as the first pattern matching
        in the `pretrained_model_name_or_path` string (in the following order):
            - contains `t5`: T5Tokenizer (T5 model)
            - contains `distilbert`: DistilBertTokenizer (DistilBert model)
            - contains `albert`: AlbertTokenizer (ALBERT model)
            - contains `camembert`: CamembertTokenizer (CamemBERT model)
            - contains `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
            - contains `longformer`: LongformerTokenizer (AllenAI Longformer model)
            - contains `roberta`: RobertaTokenizer (RoBERTa model)
            - contains `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
            - contains `bert`: BertTokenizer (Bert model)
            - contains `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
            - contains `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
            - contains `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
            - contains `xlnet`: XLNetTokenizer (XLNet model)
            - contains `xlm`: XLMTokenizer (XLM model)
            - contains `ctrl`: CTRLTokenizer (Salesforce CTRL model)
            - contains `electra`: ElectraTokenizer (Google ELECTRA model)
        The tokenizer class to instantiate is selected
        based on the `model_type` property of the config object, or when it's missing,
        falling back to using pattern matching on the `pretrained_model_name_or_path` string:
            - `t5`: T5Tokenizer (T5 model)
            - `distilbert`: DistilBertTokenizer (DistilBert model)
            - `albert`: AlbertTokenizer (ALBERT model)
            - `camembert`: CamembertTokenizer (CamemBERT model)
            - `xlm-roberta`: XLMRobertaTokenizer (XLM-RoBERTa model)
            - `longformer`: LongformerTokenizer (AllenAI Longformer model)
            - `roberta`: RobertaTokenizer (RoBERTa model)
            - `bert-base-japanese`: BertJapaneseTokenizer (Bert model)
            - `bert`: BertTokenizer (Bert model)
            - `openai-gpt`: OpenAIGPTTokenizer (OpenAI GPT model)
            - `gpt2`: GPT2Tokenizer (OpenAI GPT-2 model)
            - `transfo-xl`: TransfoXLTokenizer (Transformer-XL model)
            - `xlnet`: XLNetTokenizer (XLNet model)
            - `xlm`: XLMTokenizer (XLM model)
            - `ctrl`: CTRLTokenizer (Salesforce CTRL model)
            - `electra`: ElectraTokenizer (Google ELECTRA model)

        Params:
            pretrained_model_name_or_path: either:
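The docstring rewrite reflects why namespaced model ids work at all: the tokenizer class is now chosen from `config.model_type` first, with name pattern matching only as a fallback. A short usage sketch, assuming the model id (which appears elsewhere in this commit) is available on huggingface.co/models:

# Sketch: model_type-driven tokenizer resolution.
from transformers import AutoTokenizer

# A namespaced id no longer has to match a name pattern like "bert-...";
# the config's model_type selects the tokenizer class.
tokenizer = AutoTokenizer.from_pretrained("TurkuNLP/bert-base-finnish-cased-v1")
print(type(tokenizer).__name__)  # expected: a BERT tokenizer class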
@@ -47,9 +47,9 @@ PRETRAINED_VOCAB_FILES_MAP = {
        "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-vocab.txt",
        "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-vocab.txt",
        "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-vocab.txt",
        "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
        "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
        "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
        "TurkuNLP/bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/vocab.txt",
        "TurkuNLP/bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/vocab.txt",
        "wietsedv/bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/vocab.txt",
    }
}

@@ -69,9 +69,9 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "bert-base-cased-finetuned-mrpc": 512,
    "bert-base-german-dbmdz-cased": 512,
    "bert-base-german-dbmdz-uncased": 512,
    "bert-base-finnish-cased-v1": 512,
    "bert-base-finnish-uncased-v1": 512,
    "bert-base-dutch-cased": 512,
    "TurkuNLP/bert-base-finnish-cased-v1": 512,
    "TurkuNLP/bert-base-finnish-uncased-v1": 512,
    "wietsedv/bert-base-dutch-cased": 512,
}

PRETRAINED_INIT_CONFIGURATION = {

@@ -90,9 +90,9 @@ PRETRAINED_INIT_CONFIGURATION = {
    "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
    "bert-base-german-dbmdz-cased": {"do_lower_case": False},
    "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
    "bert-base-finnish-cased-v1": {"do_lower_case": False},
    "bert-base-finnish-uncased-v1": {"do_lower_case": True},
    "bert-base-dutch-cased": {"do_lower_case": False},
    "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
    "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
    "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
}

@@ -30,37 +30,37 @@ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/vocab.txt",
        "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/vocab.txt",
        "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/vocab.txt",
        "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/vocab.txt",
        "cl-tohoku/bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese/vocab.txt",
        "cl-tohoku/bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking/vocab.txt",
        "cl-tohoku/bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char/vocab.txt",
        "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking/vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "bert-base-japanese": 512,
    "bert-base-japanese-whole-word-masking": 512,
    "bert-base-japanese-char": 512,
    "bert-base-japanese-char-whole-word-masking": 512,
    "cl-tohoku/bert-base-japanese": 512,
    "cl-tohoku/bert-base-japanese-whole-word-masking": 512,
    "cl-tohoku/bert-base-japanese-char": 512,
    "cl-tohoku/bert-base-japanese-char-whole-word-masking": 512,
}

PRETRAINED_INIT_CONFIGURATION = {
    "bert-base-japanese": {
    "cl-tohoku/bert-base-japanese": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "wordpiece",
    },
    "bert-base-japanese-whole-word-masking": {
    "cl-tohoku/bert-base-japanese-whole-word-masking": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "wordpiece",
    },
    "bert-base-japanese-char": {
    "cl-tohoku/bert-base-japanese-char": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "character",
    },
    "bert-base-japanese-char-whole-word-masking": {
    "cl-tohoku/bert-base-japanese-char-whole-word-masking": {
        "do_lower_case": False,
        "word_tokenizer_type": "mecab",
        "subword_tokenizer_type": "character",
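The renamed entries mean the Japanese BERT checkpoints are now addressed by their namespaced ids, which carry the tokenizer configuration declared above (MeCab word tokenization plus WordPiece). A brief sketch; MeCab must be installed for this to run:

# Sketch: loading by the new cl-tohoku/ namespaced id (requires MeCab).
from transformers import BertJapaneseTokenizer

tokenizer = BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")
print(tokenizer.tokenize("吾輩は猫である"))  # MeCab + WordPiece tokens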
@@ -942,13 +942,11 @@ class PreTrainedTokenizer(SpecialTokensMixin):
        if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            if len(cls.vocab_files_names) > 1:
                raise ValueError(
                    "Calling {}.from_pretrained() with the path to a single file or url is not supported."
                    "Use a model identifier or the path to a directory instead.".format(cls.__name__)
                    f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not supported."
                    "Use a model identifier or the path to a directory instead."
                )
            logger.warning(
                "Calling {}.from_pretrained() with the path to a single file or url is deprecated".format(
                    cls.__name__
                )
                f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is deprecated"
            )
            file_id = list(cls.vocab_files_names.keys())[0]
            vocab_files[file_id] = pretrained_model_name_or_path

@@ -63,8 +63,6 @@ logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)


def set_seed(args):
    random.seed(args.seed)

@@ -411,7 +409,7 @@ def main():
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--output_dir",
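With `ALL_MODELS` gone, the example script's model argument is free-form: any huggingface.co model id or local path is accepted, rather than a value from a precomputed shortcut list. A sketch of the resulting argparse behavior; the flag name `--model_name_or_path` is assumed from run_squad.py of the time, since the hunk above does not show it:

# Sketch of the free-form model argument (flag name assumed, id illustrative).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model_name_or_path",
    default=None,
    type=str,
    required=True,
    help="Path to pretrained model or model identifier from huggingface.co/models",
)
args = parser.parse_args(["--model_name_or_path", "bert-base-uncased"])
print(args.model_name_or_path)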
@@ -57,7 +57,6 @@ class XxxConfig(PretrainedConfig):
            initializing all weight matrices.
        layer_norm_eps: The epsilon used by LayerNorm.
    """
    pretrained_config_archive_map = XXX_PRETRAINED_CONFIG_ARCHIVE_MAP
    model_type = "xxx"

    def __init__(

@@ -32,13 +32,13 @@ from .modeling_tf_utils import TFPreTrainedModel, get_initializer, shape_list
logger = logging.getLogger(__name__)

####################################################
# This dict contrains shortcut names and associated url
# for the pretrained weights provided with the models
# This list contrains shortcut names for some of
# the pretrained weights provided with the models
####################################################
TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "xxx-base-uncased": "https://cdn.huggingface.co/xxx-base-uncased-tf_model.h5",
    "xxx-large-uncased": "https://cdn.huggingface.co/xxx-large-uncased-tf_model.h5",
}
TF_XXX_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "xxx-base-uncased",
    "xxx-large-uncased",
]


####################################################

@@ -180,7 +180,6 @@ class TFXxxPreTrainedModel(TFPreTrainedModel):
    """

    config_class = XxxConfig
    pretrained_model_archive_map = TF_XXX_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"

@@ -34,13 +34,13 @@ from .modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)

####################################################
# This dict contrains shortcut names and associated url
# for the pretrained weights provided with the models
# This list contrains shortcut names for some of
# the pretrained weights provided with the models
####################################################
XXX_PRETRAINED_MODEL_ARCHIVE_MAP = {
    "xxx-base-uncased": "https://cdn.huggingface.co/xxx-base-uncased-pytorch_model.bin",
    "xxx-large-uncased": "https://cdn.huggingface.co/xxx-large-uncased-pytorch_model.bin",
}
XXX_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "xxx-base-uncased",
    "xxx-large-uncased",
]


####################################################

@@ -180,7 +180,6 @@ class XxxPreTrainedModel(PreTrainedModel):
    """

    config_class = XxxConfig
    pretrained_model_archive_map = XXX_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_xxx
    base_model_prefix = "transformer"

@@ -32,7 +32,7 @@ if is_torch_available():
        XxxForSequenceClassification,
        XxxForTokenClassification,
    )
    from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_MAP
    from transformers.modeling_xxx import XXX_PRETRAINED_MODEL_ARCHIVE_LIST


@require_torch

@@ -269,6 +269,6 @@ class XxxModelTest(ModelTesterMixin, unittest.TestCase):

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(XXX_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in XXX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = XxxModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
            self.assertIsNotNone(model)

@@ -33,7 +33,7 @@ if is_torch_available():
        AlbertForTokenClassification,
        AlbertForQuestionAnswering,
    )
    from transformers.modeling_albert import ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
    from transformers.modeling_albert import ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST


@require_torch

@@ -295,6 +295,6 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = AlbertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

@@ -40,7 +40,7 @@ if is_torch_available():
        AutoModelForTokenClassification,
        BertForTokenClassification,
    )
    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST
    from transformers.modeling_auto import (
        MODEL_MAPPING,
        MODEL_FOR_PRETRAINING_MAPPING,

@@ -56,7 +56,7 @@ class AutoModelTest(unittest.TestCase):
    @slow
    def test_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

@@ -71,7 +71,7 @@ class AutoModelTest(unittest.TestCase):
    @slow
    def test_model_for_pretraining_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

@@ -87,7 +87,7 @@ class AutoModelTest(unittest.TestCase):
    @slow
    def test_lmhead_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

@@ -100,7 +100,7 @@ class AutoModelTest(unittest.TestCase):
    @slow
    def test_sequence_classification_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

@@ -115,7 +115,7 @@ class AutoModelTest(unittest.TestCase):
    @slow
    def test_question_answering_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

@@ -128,7 +128,7 @@ class AutoModelTest(unittest.TestCase):
    @slow
    def test_token_classification_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

@@ -39,7 +39,7 @@ if is_torch_available():
        MBartTokenizer,
    )
    from transformers.modeling_bart import (
        BART_PRETRAINED_MODEL_ARCHIVE_MAP,
        BART_PRETRAINED_MODEL_ARCHIVE_LIST,
        shift_tokens_right,
        invert_mask,
        _prepare_bart_decoder_inputs,

@@ -261,7 +261,7 @@ class BartTranslationTests(unittest.TestCase):
        self.assertEqual(expected_translation_romanian, decoded[0])

    def test_mbart_enro_config(self):
        mbart_models = ["mbart-large-en-ro"]
        mbart_models = ["facebook/mbart-large-en-ro"]
        expected = {"scale_embedding": True, "output_past": True}
        for name in mbart_models:
            config = BartConfig.from_pretrained(name)

@@ -561,7 +561,7 @@ class BartModelIntegrationTests(unittest.TestCase):
    @unittest.skip("This is just too slow")
    def test_model_from_pretrained(self):
        # Forces 1.6GB download from S3 for each model
        for model_name in list(BART_PRETRAINED_MODEL_ARCHIVE_MAP.keys()):
        for model_name in BART_PRETRAINED_MODEL_ARCHIVE_LIST:
            model = BartModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

@@ -593,7 +593,7 @@ class BartModelIntegrationTests(unittest.TestCase):
        self.assertEqual(EXPECTED_SUMMARY, decoded[0])

    def test_xsum_config_generation_params(self):
        config = BartConfig.from_pretrained("bart-large-xsum")
        config = BartConfig.from_pretrained("facebook/bart-large-xsum")
        expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0)
        config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()}
        self.assertDictEqual(expected_params, config_params)

@@ -35,7 +35,7 @@ if is_torch_available():
        BertForTokenClassification,
        BertForMultipleChoice,
    )
    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
    from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_LIST


class BertModelTester:

@@ -494,6 +494,6 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = BertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

@@ -36,7 +36,7 @@ if is_torch_available():
        PreTrainedModel,
        BertModel,
        BertConfig,
        BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        top_k_top_p_filtering,
    )

@@ -824,7 +824,7 @@ class ModelUtilsTest(unittest.TestCase):
    @slow
    def test_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            config = BertConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, PretrainedConfig)

@@ -24,7 +24,7 @@ from .utils import require_torch, slow, torch_device

if is_torch_available():
    import torch
    from transformers import CTRLConfig, CTRLModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel
    from transformers import CTRLConfig, CTRLModel, CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, CTRLLMHeadModel


@require_torch

@@ -210,7 +210,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase):

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in CTRL_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = CTRLModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

@@ -247,6 +247,6 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):

    # @slow
    # def test_model_from_pretrained(self):
    #     for model_name in list(DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
    #     for model_name in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
    #         model = DistilBertModel.from_pretrained(model_name)
    #         self.assertIsNotNone(model)

@@ -32,7 +32,7 @@ if is_torch_available():
        ElectraForPreTraining,
        ElectraForSequenceClassification,
    )
    from transformers.modeling_electra import ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP
    from transformers.modeling_electra import ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST


@require_torch

@@ -312,6 +312,6 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = ElectraModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

@@ -32,7 +32,7 @@ if is_torch_available():
        FlaubertForQuestionAnsweringSimple,
        FlaubertForSequenceClassification,
    )
    from transformers.modeling_flaubert import FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP
    from transformers.modeling_flaubert import FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST


@require_torch

@@ -387,6 +387,6 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = FlaubertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

@@ -28,7 +28,7 @@ if is_torch_available():
    from transformers import (
        GPT2Config,
        GPT2Model,
        GPT2_PRETRAINED_MODEL_ARCHIVE_MAP,
        GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
        GPT2LMHeadModel,
        GPT2DoubleHeadsModel,
    )

@@ -334,7 +334,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = GPT2Model.from_pretrained(model_name)
            self.assertIsNotNone(model)

@@ -28,7 +28,7 @@ if is_torch_available():
    from transformers import (
        OpenAIGPTConfig,
        OpenAIGPTModel,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP,
        OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
        OpenAIGPTLMHeadModel,
        OpenAIGPTDoubleHeadsModel,
    )

@@ -218,7 +218,7 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = OpenAIGPTModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

@@ -29,7 +29,7 @@ if is_torch_available():
        ReformerModelWithLMHead,
        ReformerTokenizer,
        ReformerLayer,
        REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP,
        REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
    )
    import torch

@@ -503,7 +503,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = ReformerModelWithLMHead.from_pretrained(model_name)
            self.assertIsNotNone(model)

@@ -33,7 +33,7 @@ if is_torch_available():
        RobertaForTokenClassification,
    )
    from transformers.modeling_roberta import RobertaEmbeddings, RobertaForMultipleChoice, RobertaForQuestionAnswering
    from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
    from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
    from transformers.modeling_utils import create_position_ids_from_input_ids

@@ -273,7 +273,7 @@ class RobertaModelTest(ModelTesterMixin, unittest.TestCase):

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = RobertaModel.from_pretrained(model_name)
            self.assertIsNotNone(model)

@@ -26,7 +26,7 @@ from .utils import require_torch, slow, torch_device
if is_torch_available():
    import torch
    from transformers import T5Config, T5Model, T5ForConditionalGeneration
    from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_MAP
    from transformers.modeling_t5 import T5_PRETRAINED_MODEL_ARCHIVE_LIST
    from transformers.tokenization_t5 import T5Tokenizer

@@ -372,7 +372,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = T5Model.from_pretrained(model_name)
            self.assertIsNotNone(model)

@@ -30,7 +30,7 @@ if is_tf_available():
        TFAlbertForMaskedLM,
        TFAlbertForSequenceClassification,
        TFAlbertForQuestionAnswering,
        TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
        TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
    )

@@ -257,6 +257,6 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):

    @slow
    def test_model_from_pretrained(self):
        for model_name in list(TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        for model_name in TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = TFAlbertModel.from_pretrained(model_name)
            self.assertIsNotNone(model)
Some files were not shown because too many files have changed in this diff.