2020-12-08 08:55:12 +08:00
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2023-08-17 13:58:35 +08:00
"""
Utility that checks the big table in the file docs/source/en/index.md and potentially updates it.

Use from the root of the repo with:

```bash
python utils/check_table.py
```

for a check that will error in case of inconsistencies (used by `make repo-consistency`).

To auto-fix issues run:

```bash
python utils/check_table.py --fix_and_overwrite
```

which is used by `make fix-copies`.
"""
2020-12-08 08:55:12 +08:00
import argparse
import collections
import os
import re
from typing import List, Tuple

from transformers.utils import direct_transformers_import
2020-12-08 08:55:12 +08:00
# All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_table.py
# Path handed to `direct_transformers_import` to load the library from the repo sources.
TRANSFORMERS_PATH = "src/transformers"
# Folder that holds the `index.md` file read (and possibly rewritten) by `check_model_table`.
PATH_TO_DOCS = "docs/source/en"
# The current directory, i.e. the repo root when the script is launched as described above.
# NOTE(review): not referenced anywhere in the visible part of this file.
REPO_PATH = "."
2023-08-17 13:58:35 +08:00
def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> Tuple[str, int, int, List[str]]:
    """
    Find the text in filename between two prompts.

    Args:
        filename (`str`): The file to search into.
        start_prompt (`str`): A string to look for at the start of the content searched.
        end_prompt (`str`): A string that will mark the end of the content to look for.

    Returns:
        `Tuple[str, int, int, List[str]]`: The content between the prompts (with the blank lines at both ends left
        out), the index of the first line of that content, the index just past its last line, and the list of all
        the lines of the file. When there is nothing but blank lines between the prompts, the content is empty and
        both indices point at the line of the end prompt.

    Raises:
        `ValueError`: If no line starts with `start_prompt`, or if no line after it starts with `end_prompt`.
    """
    with open(filename, "r", encoding="utf-8", newline="\n") as f:
        lines = f.readlines()

    # Find the start prompt.
    start_index = 0
    while start_index < len(lines) and not lines[start_index].startswith(start_prompt):
        start_index += 1
    if start_index == len(lines):
        raise ValueError(f"Could not find a line starting with {start_prompt!r} in {filename}.")
    start_index += 1

    # Now go until the end prompt.
    end_index = start_index
    while end_index < len(lines) and not lines[end_index].startswith(end_prompt):
        end_index += 1
    if end_index == len(lines):
        raise ValueError(f"Could not find a line starting with {end_prompt!r} after the start prompt in {filename}.")
    end_index -= 1

    # Leave out the blank lines (empty or a lone newline) at both ends. The two cursors are not allowed to cross,
    # otherwise a blank-only section would give a start index past the end index.
    while start_index <= end_index and len(lines[start_index]) <= 1:
        start_index += 1
    while end_index >= start_index and len(lines[end_index]) <= 1:
        end_index -= 1
    # Make the end index exclusive so it can be used directly in slices.
    end_index += 1
    return "".join(lines[start_index:end_index]), start_index, end_index, lines
2023-08-17 13:58:35 +08:00
# Suffixes that are used to identify models. Add new ones here, separated by |
_MODEL_CLASS_SUFFIXES = r"(?:Model|Encoder|Decoder|ForConditionalGeneration)"

# Regexes that match TF/Flax/PT model names. Their only group captures what comes before the suffix (and after the
# backend prefix for TF and Flax).
_re_tf_models = re.compile(r"TF(.*)" + _MODEL_CLASS_SUFFIXES)
_re_flax_models = re.compile(r"Flax(.*)" + _MODEL_CLASS_SUFFIXES)
# Will match any TF or Flax model too so need to be in an else branch after the two previous regexes.
_re_pt_models = re.compile(r"(.*)" + _MODEL_CLASS_SUFFIXES)
2021-10-06 10:43:16 +08:00
# This is to make sure the transformers module imported is the one in the repo.
# The import is done once, when this file is loaded; the functions below read their data from this object.
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
2021-10-06 10:43:16 +08:00
2023-08-17 13:58:35 +08:00
def camel_case_split(identifier: str) -> List[str]:
    """
    Break a camel-cased name into the words it is made of.

    Args:
        identifier (`str`): The camel-cased name to parse.

    Returns:
        `List[str]`: The words of the identifier, in order (as separated by capital letters).

    Example:

    ```py
    >>> camel_case_split("CamelCasedClass")
    ["Camel", "Cased", "Class"]
    ```
    """
    # Regex thanks to https://stackoverflow.com/questions/29916065/how-to-do-camelcase-split-in-python
    # A word stops at a lowercase -> uppercase boundary, before the last capital of a run of capitals that is
    # followed by a lowercase letter, or at the end of the string.
    word_pattern = ".+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)"
    # The pattern has no capturing group, so `findall` returns the whole text of each match.
    return re.findall(word_pattern, identifier)
2023-08-17 13:58:35 +08:00
def _center_text(text: str, width: int) -> str:
    """
    Pad a text with spaces on both sides so that it sits in the middle of a cell of a given width.

    Args:
        text (`str`): The text to center.
        width (`int`): The desired width of the result.

    Returns:
        `str`: The original `text` surrounded by spaces. When the padding cannot be split evenly, the extra space
        goes on the right; when `text` is wider than `width`, it is returned without padding.
    """
    # The two status emojis are counted as two characters, whatever `len` says.
    visible_length = 2 if text in ("✅", "❌") else len(text)
    total_padding = width - visible_length
    left_padding = total_padding // 2
    return " " * left_padding + text + " " * (total_padding - left_padding)
2023-10-05 22:20:41 +08:00
# Model names whose doc link is not the default `[name](model_doc/key)` built from `MODEL_NAMES_MAPPING`:
# `get_model_table_from_auto_modules` uses the link given here instead.
SPECIAL_MODEL_NAME_LINK_MAPPING = {
    "Data2VecAudio": "[Data2VecAudio](model_doc/data2vec)",
    "Data2VecText": "[Data2VecText](model_doc/data2vec)",
    "Data2VecVision": "[Data2VecVision](model_doc/data2vec)",
    "DonutSwin": "[DonutSwin](model_doc/donut)",
}

# Additional model names that get a row in the table, mapped to the name of the model whose config prefix is used to
# look up their PyTorch/TensorFlow/Flax support.
MODEL_NAMES_WITH_SAME_CONFIG = {
    "BARThez": "BART",
    "BARTpho": "BART",
    "BertJapanese": "BERT",
    "BERTweet": "BERT",
    "BORT": "BERT",
    "ByT5": "T5",
    "CPM": "OpenAI GPT-2",
    "DePlot": "Pix2Struct",
    "DialoGPT": "OpenAI GPT-2",
    "DiT": "BEiT",
    "FLAN-T5": "T5",
    "FLAN-UL2": "T5",
    "HerBERT": "BERT",
    "LayoutXLM": "LayoutLMv2",
    "Llama2": "LLaMA",
    "MADLAD-400": "T5",
    "MatCha": "Pix2Struct",
    "mBART-50": "mBART",
    "Megatron-GPT2": "OpenAI GPT-2",
    "mLUKE": "LUKE",
    "MMS": "Wav2Vec2",
    "NLLB": "M2M100",
    "PhoBERT": "BERT",
    "T5v1.1": "T5",
    "TAPEX": "BART",
    "UL2": "T5",
    "Wav2Vec2Phoneme": "Wav2Vec2",
    "XLM-V": "XLM-RoBERTa",
    "XLS-R": "Wav2Vec2",
    "XLSR-Wav2Vec2": "Wav2Vec2",
}

# Model names that are skipped when the rows of the table are written.
MODEL_NAMES_TO_IGNORE = ["CLIPVisionModel", "SiglipVisionModel", "ChineseCLIPVisionModel"]
2023-10-05 22:20:41 +08:00
2023-08-17 13:58:35 +08:00
def get_model_table_from_auto_modules() -> str:
    """
    Generates an up-to-date model table from the content of the auto modules.

    Returns:
        `str`: The Markdown table: a header row, an alignment row, then one row per model (sorted by name without
        taking the case into account). Each model row holds the doc link of the model followed by a ✅ or a ❌ for
        the PyTorch, TensorFlow and Flax support. Every row ends with a newline.
    """
    auto_config = transformers_module.models.auto.configuration_auto

    # Dictionary model names to config.
    config_mapping_names = auto_config.CONFIG_MAPPING_NAMES
    model_name_to_config = {
        name: config_mapping_names[code]
        for code, name in transformers_module.MODEL_NAMES_MAPPING.items()
        if code in config_mapping_names
    }
    model_name_to_prefix = {name: config.replace("Config", "") for name, config in model_name_to_config.items()}
    # Built once: membership is tested for every candidate prefix of every model class found below.
    known_prefixes = set(model_name_to_prefix.values())

    # Sets of the model prefixes that have a backend in PT/TF/Flax.
    pt_prefixes, tf_prefixes, flax_prefixes = set(), set(), set()
    # The order matters: the PT regex also matches TF and Flax names, so it has to be tried last.
    backend_patterns = (
        (_re_tf_models, tf_prefixes),
        (_re_flax_models, flax_prefixes),
        (_re_pt_models, pt_prefixes),
    )

    # Let's lookup through all transformers object (once).
    for attr_name in dir(transformers_module):
        for pattern, supported_prefixes in backend_patterns:
            match = pattern.match(attr_name)
            if match is not None:
                break
        else:
            # The name does not look like a model class for any backend.
            continue

        candidate = match.group(1)
        while candidate:
            if candidate in known_prefixes:
                supported_prefixes.add(candidate)
                break
            # Try again after removing the last word in the name
            candidate = "".join(camel_case_split(candidate)[:-1])

    # Let's build that table!
    # MaskFormerSwin and TimmBackbone are backbones and so not meant to be loaded and used on their own. Instead, they
    # define architectures which can be loaded using the AutoBackbone API. Speech2Text2 and the names listed in
    # MODEL_NAMES_TO_IGNORE do not get a row either.
    names_to_exclude = {"MaskFormerSwin", "TimmBackbone", "Speech2Text2", *MODEL_NAMES_TO_IGNORE}
    model_names = [
        name for name in [*model_name_to_config, *MODEL_NAMES_WITH_SAME_CONFIG] if name not in names_to_exclude
    ]
    model_names.sort(key=str.lower)

    # Model name to doc link mapping, with the special model names taking precedence over the default link.
    model_names_mapping = auto_config.MODEL_NAMES_MAPPING
    model_name_to_link_mapping = {
        value: SPECIAL_MODEL_NAME_LINK_MAPPING.get(value, f"[{value}](model_doc/{key})")
        for key, value in model_names_mapping.items()
    }

    columns = ["Model", "PyTorch support", "TensorFlow support", "Flax Support"]
    # We'll need widths to properly display everything in the center (+2 is to leave one extra space on each side).
    widths = [len(c) + 2 for c in columns]
    # The first column has to fit the longest doc link (computed over every known link, not only the rows kept).
    widths[0] = max(len(doc_link) for doc_link in model_name_to_link_mapping.values()) + 2

    # Build the table per se
    table_lines = [
        "|" + "|".join(_center_text(c, w) for c, w in zip(columns, widths)) + "|\n",
        # Use ":-----:" format to center-aligned table cell texts
        "|" + "|".join(":" + "-" * (w - 2) + ":" for w in widths) + "|\n",
    ]

    check = {True: "✅", False: "❌"}
    for name in model_names:
        # Models sharing the config of another model report the backend support of that other model.
        prefix = model_name_to_prefix[MODEL_NAMES_WITH_SAME_CONFIG.get(name, name)]
        cells = [
            model_name_to_link_mapping[name],
            check[prefix in pt_prefixes],
            check[prefix in tf_prefixes],
            check[prefix in flax_prefixes],
        ]
        table_lines.append("|" + "|".join(_center_text(cell, w) for cell, w in zip(cells, widths)) + "|\n")
    return "".join(table_lines)
def check_model_table(overwrite=False):
    """
    Compare the model table of `index.md` with the one generated from the library and potentially rewrite it.

    Args:
        overwrite (`bool`, *optional*, defaults to `False`):
            Whether or not to overwrite the table when it's not up to date.

    Raises:
        `ValueError`: If the table is not up to date and `overwrite` is `False`.
    """
    index_path = os.path.join(PATH_TO_DOCS, "index.md")
    current_table, start_index, end_index, lines = _find_text_in_file(
        filename=index_path,
        start_prompt="<!--This table is updated automatically from the auto modules",
        end_prompt="<!-- End table-->",
    )
    new_table = get_model_table_from_auto_modules()

    # Nothing to do when the file already holds the expected table.
    if current_table == new_table:
        return

    if not overwrite:
        raise ValueError(
            "The model table in the `index.md` has not been updated. Run `make fix-copies` to fix this."
        )

    # Keep everything around the table untouched and swap in the new table.
    updated_lines = lines[:start_index] + [new_table] + lines[end_index:]
    with open(index_path, "w", encoding="utf-8", newline="\n") as f:
        f.writelines(updated_lines)
if __name__ == "__main__":
    # Command line entry point: check the table, and rewrite it when `--fix_and_overwrite` is passed.
    cli_parser = argparse.ArgumentParser()
    cli_parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    cli_args = cli_parser.parse_args()

    check_model_table(cli_args.fix_and_overwrite)