Add _CHECKPOINT_FOR_DOC to all models (#12811)

* Add _CHECKPOINT_FOR_DOC

* Update src/transformers/models/funnel/modeling_funnel.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Lysandre Debut 2021-07-21 14:29:43 +02:00 committed by GitHub
parent 786ced3639
commit ac3cb660ca
19 changed files with 26 additions and 7 deletions
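
For context, the change in each file is mechanical: a module-level _CHECKPOINT_FOR_DOC constant is added next to the existing _CONFIG_FOR_DOC/_TOKENIZER_FOR_DOC constants, and hard-coded checkpoint strings passed to add_code_sample_docstrings(...) are replaced by it. A minimal sketch of the resulting pattern, using the Funnel names from the hunks below (MyModel and its bare forward are simplified placeholders, not the real class):

    from transformers.file_utils import add_code_sample_docstrings
    from transformers.modeling_outputs import BaseModelOutput

    # Module-level doc constants, defined once near the top of the modeling file.
    _CONFIG_FOR_DOC = "FunnelConfig"
    _TOKENIZER_FOR_DOC = "FunnelTokenizer"
    _CHECKPOINT_FOR_DOC = "funnel-transformer/small"

    class MyModel:
        @add_code_sample_docstrings(
            tokenizer_class=_TOKENIZER_FOR_DOC,
            checkpoint=_CHECKPOINT_FOR_DOC,  # previously a string literal repeated at every call site
            output_type=BaseModelOutput,
            config_class=_CONFIG_FOR_DOC,
        )
        def forward(self, input_ids=None):
            """Runs the model; the decorator appends a generated usage sample to this docstring."""

Centralizing the checkpoint in a single constant means the documented examples for a model can be retargeted to another checkpoint by editing one line.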

src/transformers/models/blenderbot/modeling_blenderbot.py

@@ -51,6 +51,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "BlenderbotConfig"
 _TOKENIZER_FOR_DOC = "BlenderbotTokenizer"
+_CHECKPOINT_FOR_DOC = "facebook/blenderbot-400M-distill"
 BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = [

src/transformers/models/blenderbot_small/modeling_blenderbot_small.py

@@ -48,6 +48,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "BlenderbotSmallConfig"
 _TOKENIZER_FOR_DOC = "BlenderbotSmallTokenizer"
+_CHECKPOINT_FOR_DOC = "facebook/blenderbot_small-90M"
 BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = [

src/transformers/models/clip/modeling_clip.py

@@ -36,6 +36,7 @@ from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
 logger = logging.get_logger(__name__)
+_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
 CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "openai/clip-vit-base-patch32",

src/transformers/models/deit/modeling_deit.py

@@ -41,6 +41,7 @@ from .configuration_deit import DeiTConfig
 logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "DeiTConfig"
+_CHECKPOINT_FOR_DOC = "facebook/deit-base-distilled-patch16-224"
 DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/deit-base-distilled-patch16-224",

src/transformers/models/detr/modeling_detr.py

@@ -52,6 +52,7 @@ if is_timm_available():
 logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "DetrConfig"
+_CHECKPOINT_FOR_DOC = "facebook/detr-resnet-50"
 DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/detr-resnet-50",

src/transformers/models/dpr/modeling_dpr.py

@@ -37,6 +37,7 @@ from .configuration_dpr import DPRConfig
 logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "DPRConfig"
+_CHECKPOINT_FOR_DOC = "facebook/dpr-ctx_encoder-single-nq-base"
 DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/dpr-ctx_encoder-single-nq-base",

src/transformers/models/funnel/modeling_funnel.py

@@ -48,6 +48,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "FunnelConfig"
 _TOKENIZER_FOR_DOC = "FunnelTokenizer"
+_CHECKPOINT_FOR_DOC = "funnel-transformer/small"
 FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "funnel-transformer/small",  # B4-4-4H768
@@ -987,7 +988,7 @@ class FunnelModel(FunnelPreTrainedModel):
     @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="funnel-transformer/small",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -1174,7 +1175,7 @@ class FunnelForMaskedLM(FunnelPreTrainedModel):
     @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="funnel-transformer/small",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=MaskedLMOutput,
         config_class=_CONFIG_FOR_DOC,
         mask="<mask>",
@@ -1424,7 +1425,7 @@ class FunnelForTokenClassification(FunnelPreTrainedModel):
     @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="funnel-transformer/small",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=TokenClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -1506,7 +1507,7 @@ class FunnelForQuestionAnswering(FunnelPreTrainedModel):
     @add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="funnel-transformer/small",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=QuestionAnsweringModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )

src/transformers/models/hubert/modeling_hubert.py

@@ -34,6 +34,7 @@ from .configuration_hubert import HubertConfig
 logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "HubertConfig"
+_CHECKPOINT_FOR_DOC = "facebook/hubert-base-ls960"
 HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/hubert-base-ls960",

src/transformers/models/layoutlm/modeling_layoutlm.py

@@ -45,6 +45,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "LayoutLMConfig"
 _TOKENIZER_FOR_DOC = "LayoutLMTokenizer"
+_CHECKPOINT_FOR_DOC = "microsoft/layoutlm-base-uncased"
 LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "layoutlm-base-uncased",

src/transformers/models/luke/modeling_luke.py

@@ -39,6 +39,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "LukeConfig"
 _TOKENIZER_FOR_DOC = "LukeTokenizer"
+_CHECKPOINT_FOR_DOC = "studio-ousia/luke-base"
 LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "studio-ousia/luke-base",

src/transformers/models/m2m_100/modeling_m2m_100.py

@@ -46,6 +46,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "M2M100Config"
 _TOKENIZER_FOR_DOC = "M2M100Tokenizer"
+_CHECKPOINT_FOR_DOC = "facebook/m2m100_418M"
 M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1117,7 +1118,7 @@ class M2M100Model(M2M100PreTrainedModel):
     @add_start_docstrings_to_model_forward(M2M_100_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="facebook/m2m100_418M",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Seq2SeqModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )

src/transformers/models/marian/modeling_marian.py

@@ -49,6 +49,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "MarianConfig"
 _TOKENIZER_FOR_DOC = "MarianTokenizer"
+_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"
 MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST = [

src/transformers/models/prophetnet/modeling_prophetnet.py

@@ -42,6 +42,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "ProphetNetConfig"
 _TOKENIZER_FOR_DOC = "ProphetNetTokenizer"
+_CHECKPOINT_FOR_DOC = "microsoft/prophetnet-large-uncased"
 PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "microsoft/prophetnet-large-uncased",

src/transformers/models/speech_to_text/modeling_speech_to_text.py

@@ -45,6 +45,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "Speech2TextConfig"
 _TOKENIZER_FOR_DOC = "Speech2TextTokenizer"
+_CHECKPOINT_FOR_DOC = "facebook/s2t-small-librispeech-asr"
 SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [
@@ -1131,7 +1132,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
     @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="s2t_transformer_s",
+        checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=Seq2SeqModelOutput,
         config_class=_CONFIG_FOR_DOC,
     )

src/transformers/models/t5/modeling_t5.py

@@ -50,6 +50,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "T5Config"
 _TOKENIZER_FOR_DOC = "T5Tokenizer"
+_CHECKPOINT_FOR_DOC = "t5-small"
 ####################################################
 # This dict contains ids and associated url

src/transformers/models/tapas/modeling_tapas.py

@@ -54,6 +54,7 @@ logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "TapasConfig"
 _TOKENIZER_FOR_DOC = "TapasTokenizer"
+_CHECKPOINT_FOR_DOC = "google/tapas-base"
 TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # large models

src/transformers/models/visual_bert/modeling_visual_bert.py

@@ -50,6 +50,7 @@ from .configuration_visual_bert import VisualBertConfig
 logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "VisualBertConfig"
+_CHECKPOINT_FOR_DOC = "uclanlp/visualbert-vqa-coco-pre"
 VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "uclanlp/visualbert-vqa",

src/transformers/models/vit/modeling_vit.py

@@ -34,9 +34,10 @@ from .configuration_vit import ViTConfig
 logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "ViTConfig"
+_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224"
 VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "nielsr/vit-base-patch16-224",
+    "google/vit-base-patch16-224",
     # See all ViT models at https://huggingface.co/models?filter=vit
 ]

src/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -40,6 +40,7 @@ from .configuration_wav2vec2 import Wav2Vec2Config
 logger = logging.get_logger(__name__)
 _CONFIG_FOR_DOC = "Wav2Vec2Config"
+_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
 WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/wav2vec2-base-960h",