Add _CHECKPOINT_FOR_DOC to all models (#12811)
* Add _CHECKPOINT_FOR_DOC * Update src/transformers/models/funnel/modeling_funnel.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
786ced3639
commit
ac3cb660ca
|
@ -51,6 +51,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
_CONFIG_FOR_DOC = "BlenderbotConfig"
|
||||
_TOKENIZER_FOR_DOC = "BlenderbotTokenizer"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/blenderbot-400M-distill"
|
||||
|
||||
|
||||
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
|
|
|
@ -48,6 +48,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
_CONFIG_FOR_DOC = "BlenderbotSmallConfig"
|
||||
_TOKENIZER_FOR_DOC = "BlenderbotSmallTokenizer"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/blenderbot_small-90M"
|
||||
|
||||
|
||||
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
|
|
|
@ -36,6 +36,7 @@ from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
|||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CHECKPOINT_FOR_DOC = "openai/clip-vit-base-patch32"
|
||||
|
||||
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"openai/clip-vit-base-patch32",
|
||||
|
|
|
@ -41,6 +41,7 @@ from .configuration_deit import DeiTConfig
|
|||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "DeiTConfig"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/deit-base-distilled-patch16-224"
|
||||
|
||||
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/deit-base-distilled-patch16-224",
|
||||
|
|
|
@ -52,6 +52,7 @@ if is_timm_available():
|
|||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "DetrConfig"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/detr-resnet-50"
|
||||
|
||||
DETR_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/detr-resnet-50",
|
||||
|
|
|
@ -37,6 +37,7 @@ from .configuration_dpr import DPRConfig
|
|||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "DPRConfig"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/dpr-ctx_encoder-single-nq-base"
|
||||
|
||||
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/dpr-ctx_encoder-single-nq-base",
|
||||
|
|
|
@ -48,6 +48,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
_CONFIG_FOR_DOC = "FunnelConfig"
|
||||
_TOKENIZER_FOR_DOC = "FunnelTokenizer"
|
||||
_CHECKPOINT_FOR_DOC = "funnel-transformer/small"
|
||||
|
||||
FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"funnel-transformer/small", # B4-4-4H768
|
||||
|
@ -987,7 +988,7 @@ class FunnelModel(FunnelPreTrainedModel):
|
|||
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="funnel-transformer/small",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=BaseModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
@ -1174,7 +1175,7 @@ class FunnelForMaskedLM(FunnelPreTrainedModel):
|
|||
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="funnel-transformer/small",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=MaskedLMOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
mask="<mask>",
|
||||
|
@ -1424,7 +1425,7 @@ class FunnelForTokenClassification(FunnelPreTrainedModel):
|
|||
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="funnel-transformer/small",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=TokenClassifierOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
@ -1506,7 +1507,7 @@ class FunnelForQuestionAnswering(FunnelPreTrainedModel):
|
|||
@add_start_docstrings_to_model_forward(FUNNEL_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="funnel-transformer/small",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=QuestionAnsweringModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
|
|
@ -34,6 +34,7 @@ from .configuration_hubert import HubertConfig
|
|||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "HubertConfig"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/hubert-base-ls960"
|
||||
|
||||
HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/hubert-base-ls960",
|
||||
|
|
|
@ -45,6 +45,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
_CONFIG_FOR_DOC = "LayoutLMConfig"
|
||||
_TOKENIZER_FOR_DOC = "LayoutLMTokenizer"
|
||||
_CHECKPOINT_FOR_DOC = "microsoft/layoutlm-base-uncased"
|
||||
|
||||
LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"layoutlm-base-uncased",
|
||||
|
|
|
@ -39,6 +39,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
_CONFIG_FOR_DOC = "LukeConfig"
|
||||
_TOKENIZER_FOR_DOC = "LukeTokenizer"
|
||||
_CHECKPOINT_FOR_DOC = "studio-ousia/luke-base"
|
||||
|
||||
LUKE_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"studio-ousia/luke-base",
|
||||
|
|
|
@ -46,6 +46,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
_CONFIG_FOR_DOC = "M2M100Config"
|
||||
_TOKENIZER_FOR_DOC = "M2M100Tokenizer"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/m2m100_418M"
|
||||
|
||||
|
||||
M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
|
@ -1117,7 +1118,7 @@ class M2M100Model(M2M100PreTrainedModel):
|
|||
@add_start_docstrings_to_model_forward(M2M_100_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="facebook/m2m100_418M",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=Seq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
|
|
@ -49,6 +49,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
_CONFIG_FOR_DOC = "MarianConfig"
|
||||
_TOKENIZER_FOR_DOC = "MarianTokenizer"
|
||||
_CHECKPOINT_FOR_DOC = "Helsinki-NLP/opus-mt-en-de"
|
||||
|
||||
|
||||
MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
|
|
|
@ -42,6 +42,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
_CONFIG_FOR_DOC = "ProphenetConfig"
|
||||
_TOKENIZER_FOR_DOC = "ProphetNetTokenizer"
|
||||
_CHECKPOINT_FOR_DOC = "microsoft/prophetnet-large-uncased"
|
||||
|
||||
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"microsoft/prophetnet-large-uncased",
|
||||
|
|
|
@ -45,6 +45,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
_CONFIG_FOR_DOC = "Speech2TextConfig"
|
||||
_TOKENIZER_FOR_DOC = "Speech2TextTokenizer"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/s2t-small-librispeech-asr"
|
||||
|
||||
|
||||
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
|
@ -1131,7 +1132,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel):
|
|||
@add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="s2t_transformer_s",
|
||||
checkpoint=_CHECKPOINT_FOR_DOC,
|
||||
output_type=Seq2SeqModelOutput,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
|
|
|
@ -50,6 +50,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
_CONFIG_FOR_DOC = "T5Config"
|
||||
_TOKENIZER_FOR_DOC = "T5Tokenizer"
|
||||
_CHECKPOINT_FOR_DOC = "t5-small"
|
||||
|
||||
####################################################
|
||||
# This dict contains ids and associated url
|
||||
|
|
|
@ -54,6 +54,7 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
_CONFIG_FOR_DOC = "TapasConfig"
|
||||
_TOKENIZER_FOR_DOC = "TapasTokenizer"
|
||||
_TOKENIZER_FOR_DOC = "google/tapas-base"
|
||||
|
||||
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
# large models
|
||||
|
|
|
@ -50,6 +50,7 @@ from .configuration_visual_bert import VisualBertConfig
|
|||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "VisualBertConfig"
|
||||
_CHECKPOINT_FOR_DOC = "uclanlp/visualbert-vqa-coco-pre"
|
||||
|
||||
VISUAL_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"uclanlp/visualbert-vqa",
|
||||
|
|
|
@ -34,9 +34,10 @@ from .configuration_vit import ViTConfig
|
|||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "ViTConfig"
|
||||
_CHECKPOINT_FOR_DOC = "google/vit-base-patch16-224"
|
||||
|
||||
VIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"nielsr/vit-base-patch16-224",
|
||||
"google/vit-base-patch16-224",
|
||||
# See all ViT models at https://huggingface.co/models?filter=vit
|
||||
]
|
||||
|
||||
|
|
|
@ -40,6 +40,7 @@ from .configuration_wav2vec2 import Wav2Vec2Config
|
|||
logger = logging.get_logger(__name__)
|
||||
|
||||
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
||||
_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-base-960h"
|
||||
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/wav2vec2-base-960h",
|
||||
|
|
Loading…
Reference in New Issue