Doc checks (#25408)

* Document check_dummies

* Type hints and doc in other files

* Document check inits

* Add documentation to

* Address review comments
Sylvain Gugger 2023-08-10 10:53:22 +02:00 committed by GitHub
parent b14d4641f6
commit 16edf4d9fd
6 changed files with 459 additions and 224 deletions

View File

@ -40,6 +40,7 @@ import argparse
import glob
import os
import re
from typing import List, Optional, Tuple
import black
from doc_builder.style_doc import style_docstrings_in_code
@ -125,14 +126,22 @@ LOCALIZED_READMES = {
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
def _should_continue(line, indent):
def _should_continue(line: str, indent: str) -> bool:
# Helper function. Returns `True` if `line` is empty, starts with the `indent`, or is the closing parenthesis of a
# function definition.
return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
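
As an illustration (not part of the commit), the closing-parenthesis regex above accepts both annotated and plain ends of a multi-line signature:

```python
import re

_re_close_paren = re.compile(r"^\s*\)(\s*->.*:|:)\s*$")

assert _re_close_paren.search(") -> bool:") is not None  # closing line with a return annotation
assert _re_close_paren.search("):") is not None          # plain closing line
assert _re_close_paren.search("def f(") is None          # a new definition does not match
```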
def find_code_in_transformers(object_name):
"""Find and return the code source code of `object_name`."""
def find_code_in_transformers(object_name: str) -> str:
"""
Find and return the source code of an object.
Args:
object_name (`str`): The name of the object we want the source code of.
Returns:
`str`: The source code of the object.
"""
parts = object_name.split(".")
i = 0
@ -181,7 +190,16 @@ _re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
_re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
def get_indent(code):
def get_indent(code: str) -> str:
"""
Find the indent in the first non-empty line in a code sample.
Args:
code (`str`): The code to inspect.
Returns:
`str`: The indent (as a string).
"""
lines = code.split("\n")
idx = 0
while idx < len(lines) and len(lines[idx]) == 0:
@ -191,9 +209,15 @@ def get_indent(code):
return ""
def blackify(code):
def blackify(code: str) -> str:
"""
Applies the black part of our `make style` command to `code`.
Applies the black part of our `make style` command to some code.
Args:
code (`str`): The code to format.
Returns:
`str`: The formatted code.
"""
has_indent = len(get_indent(code)) > 0
if has_indent:
@ -204,14 +228,22 @@ def blackify(code):
return result[len("class Bla:\n") :] if has_indent else result
def check_codes_match(observed_code, theoretical_code):
def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int]:
"""
Checks if the code in `observed_code` and `theoretical_code` match with the exception of the class/function name.
Returns the index of the first line where there is a difference (if any) and `None` if the codes match.
Checks if two versions of code match with the exception of the class/function name.
Args:
observed_code (`str`): The code found.
theoretical_code (`str`): The code to match.
Returns:
`Optional[int]`: The index of the first line where there is a difference (if any) and `None` if the codes
match.
"""
observed_code_header = observed_code.split("\n")[0]
theoretical_code_header = theoretical_code.split("\n")[0]
# Catch the function/class name: it is expected that those do not match.
_re_class_match = re.compile(r"class\s+([^\(:]+)(?:\(|:)")
_re_func_match = re.compile(r"def\s+([^\(]+)\(")
for re_pattern in [_re_class_match, _re_func_match]:
@ -220,6 +252,7 @@ def check_codes_match(observed_code, theoretical_code):
theoretical_name = re_pattern.search(theoretical_code_header).groups()[0]
theoretical_code_header = theoretical_code_header.replace(theoretical_name, observed_obj_name)
# Find the first diff. Line 0 is special since we need to compare with the function/class names ignored.
diff_index = 0
if theoretical_code_header != observed_code_header:
return 0
@ -231,11 +264,19 @@ def check_codes_match(observed_code, theoretical_code):
diff_index += 1
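
Illustrative behavior on two hypothetical snippets (only the names are allowed to differ; bodies are compared line by line):

```python
observed = "def bert_forward(x):\n    return x + 1"
theoretical = "def roberta_forward(x):\n    return x + 1"
assert check_codes_match(observed, theoretical) is None  # names differ, bodies match
assert check_codes_match(observed, "def roberta_forward(x):\n    return x + 2") == 1
```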
def is_copy_consistent(filename, overwrite=False):
def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[Tuple[str, int]]]:
"""
Check if the code commented as a copy in `filename` matches the original.
Check if the code commented as a copy in a file matches the original.
Return the differences or overwrites the content depending on `overwrite`.
Args:
filename (`str`):
The name of the file to check.
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to overwrite the copies when they don't match.
Returns:
`Optional[List[Tuple[str, int]]]`: If `overwrite=False`, returns the list of differences as tuples `(str, int)`
with the name of the object having a diff and the line number where there is the first diff.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
@ -308,8 +349,12 @@ def is_copy_consistent(filename, overwrite=False):
def check_copies(overwrite: bool = False):
"""
Check every file is copy-consistent with the original and maybe `overwrite` content when it is not. Also check the
model list in the main README and other READMEs/index.md are consistent.
Check every file is copy-consistent with the original. Also check the model list in the main README and other
READMEs/index.md are consistent.
Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to overwrite the copies when they don't match.
"""
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
diffs = []
@ -328,8 +373,11 @@ def check_copies(overwrite: bool = False):
def check_full_copies(overwrite: bool = False):
"""
Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent and maybe
`overwrite` to fix issues.
Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent.
Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to overwrite the copies when they don't match.
"""
diffs = []
for target, source in FULL_COPIES.items():
@ -354,8 +402,18 @@ def check_full_copies(overwrite: bool = False):
)
def get_model_list(filename, start_prompt, end_prompt):
"""Extracts the model list from a README, between `start_prompt` and `end_prompt`."""
def get_model_list(filename: str, start_prompt: str, end_prompt: str) -> str:
"""
Extracts the model list from a README.
Args:
filename (`str`): The name of the README file to check.
start_prompt (`str`): The string to look for that introduces the model list.
end_prompt (`str`): The string to look for that ends the model list.
Returns:
`str`: The model list.
"""
with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Find the start of the list.
@ -368,6 +426,7 @@ def get_model_list(filename, start_prompt, end_prompt):
current_line = ""
end_index = start_index
# Keep going until the end of the list.
while not lines[end_index].startswith(end_prompt):
if lines[end_index].startswith("1."):
if len(current_line) > 1:
@ -382,7 +441,7 @@ def get_model_list(filename, start_prompt, end_prompt):
return "".join(result)
def convert_to_localized_md(model_list, localized_model_list, format_str):
def convert_to_localized_md(model_list: str, localized_model_list: str, format_str: str) -> Tuple[bool, str]:
"""
Compare the model list from the main README to the one in a localized README.
@ -458,19 +517,33 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
return readmes_match, "\n".join((x[1] for x in sorted_index)) + "\n"
def convert_readme_to_index(model_list):
def convert_readme_to_index(model_list: str) -> str:
"""
Converts the model list of the README to the index.md format.
Converts the model list of the README to the index.md format (adapting links to the doc to relative links).
Args:
model_list (`str`): The model list of the main README.
Returns:
`str`: The model list in the format for the index.
"""
# We need to replace both links to the main doc and stable doc (the order of the next two instructions is important).
model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
return model_list.replace("https://huggingface.co/docs/transformers/", "")
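
For example, on a hypothetical README entry (stripping the `main/` link first ensures the second replacement leaves no dangling `main/` prefix):

```python
entry = "1. **[BERT](https://huggingface.co/docs/transformers/main/model_doc/bert)**"
assert convert_readme_to_index(entry) == "1. **[BERT](model_doc/bert)**"
```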
def _find_text_in_file(filename, start_prompt, end_prompt):
def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> Tuple[str, int, int, List[str]]:
"""
Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
lines.
Find the text in a file between two prompts.
Args:
filename (`str`): The name of the file to look into.
start_prompt (`str`): The string to look for that introduces the content looked for.
end_prompt (`str`): The string to look for that ends the content looked for.
Returns:
`Tuple[str, int, int, List[str]]`: The content between the two prompts, the index of the start line in the
original file, the index of the end line in the original file and the list of lines of that file.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
@ -493,9 +566,13 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
return "".join(lines[start_index:end_index]), start_index, end_index, lines
def check_model_list_copy(overwrite=False, max_per_line=119):
def check_model_list_copy(overwrite: bool = False):
"""
Check the model list in the README is consistent with the ones in the other READMEs and also with `index.md`.
Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to overwrite the copies when they don't match.
"""
# Fix potential doc links in the README
with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
@ -526,6 +603,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
end_prompt=LOCALIZED_READMES["README.md"]["end_prompt"],
)
# Build the converted Markdown.
converted_md_lists = []
for filename, value in LOCALIZED_READMES.items():
_start_prompt = value["start_prompt"]
@ -537,6 +615,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
converted_md_lists.append((filename, readmes_match, converted_md_list, _start_prompt, _end_prompt))
# Build the converted index and compare it.
converted_md_list = convert_readme_to_index(md_list)
if converted_md_list != index_list:
if overwrite:
@ -548,6 +627,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
"`make fix-copies` to fix this."
)
# Compare the converted Markdowns
for converted_md_list in converted_md_lists:
filename, readmes_match, converted_md, _start_prompt, _end_prompt = converted_md_list
@ -606,10 +686,13 @@ README_TEMPLATE = (
)
def check_readme(overwrite=False):
def check_readme(overwrite: bool = False):
"""
Check if the main README contains all the models in the library or not. If `overwrite`, will add an entry for the
missing models using `README_TEMPLATE`.
Check if the main README contains all the models in the library or not.
Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to add an entry for the missing models using `README_TEMPLATE`.
"""
info = LOCALIZED_READMES["README.md"]
models, start_index, end_index, lines = _find_text_in_file(

View File

@ -34,6 +34,7 @@ python utils/check_doc_toc.py --fix_and_overwrite
import argparse
from collections import defaultdict
from typing import List
import yaml
@ -41,7 +42,7 @@ import yaml
PATH_TO_TOC = "docs/source/en/_toctree.yml"
def clean_model_doc_toc(model_doc):
def clean_model_doc_toc(model_doc: List[dict]) -> List[dict]:
"""
Cleans a section of the table of contents of the model documentation (one specific modality) by removing duplicates
and sorting models alphabetically.
@ -77,7 +78,7 @@ def clean_model_doc_toc(model_doc):
return sorted(new_doc, key=lambda s: s["title"].lower())
def check_model_doc(overwrite=False):
def check_model_doc(overwrite: bool = False):
"""
Check that the content of the table of contents in `_toctree.yml` is clean (no duplicates and sorted for the model
API doc) and potentially auto-cleans it.

View File

@ -40,7 +40,16 @@ REPO_PATH = "."
DOCTEST_FILE_PATHS = ["documentation_tests.txt", "slow_documentation_tests.txt"]
def clean_doctest_list(doctest_file, overwrite=False):
def clean_doctest_list(doctest_file: str, overwrite: bool = False):
"""
Cleans the doctest list in a given file.
Args:
doctest_file (`str`):
The path to the doctest file to check or clean.
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to fix problems. If `False`, will error when the file is not clean.
"""
non_existent_paths = []
all_paths = []
with open(doctest_file, "r", encoding="utf-8") as f:

View File

@ -12,10 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script is responsible for making sure the dummies in utils/dummy_xxx_objects.py are up to date with the main init.
Why dummies? This is to make sure that a user can always import all objects from `transformers`, even if they don't
have the necessary extra libs installed. Those objects will then raise a helpful error message whenever the user tries
to access one of their methods.
Usage (from the root of the repo):
Check that the dummy files are up to date (used in `make repo-consistency`):
```bash
python utils/check_dummies.py
```
Update the dummy files if needed (used in `make fix-copies`):
```bash
python utils/check_dummies.py --fix_and_overwrite
```
"""
import argparse
import os
import re
from typing import Dict, List, Optional
# All paths are set with the intent you should run this script from the root of the repo with the command
@ -26,13 +47,16 @@ PATH_TO_TRANSFORMERS = "src/transformers"
_re_backend = re.compile(r"is\_([a-z_]*)_available()")
# Matches from xxx import bla
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
# Matches if not is_xxx_available()
_re_test_backend = re.compile(r"^\s+if\s+not\s+\(?is\_[a-z_]*\_available\(\)")
# Template for the dummy objects.
DUMMY_CONSTANT = """
{0} = None
"""
DUMMY_CLASS = """
class {0}(metaclass=DummyObject):
_backends = {1}
@ -48,8 +72,18 @@ def {0}(*args, **kwargs):
"""
def find_backend(line):
"""Find one (or multiple) backend in a code line of the init."""
def find_backend(line: str) -> Optional[str]:
"""
Find one (or multiple) backends in a code line of the init.
Args:
line (`str`): A code line in an init file.
Returns:
`Optional[str]`: If one (or several) backends are found, returns them. In the case of multiple backends (the line
contains `if is_xxx_available() and is_yyy_available()`), returns all backends joined on `_and_` (so
`xxx_and_yyy` for instance).
"""
if _re_test_backend.search(line) is None:
return None
backends = [b[0] for b in _re_backend.findall(line)]
@ -57,8 +91,13 @@ def find_backend(line):
return "_and_".join(backends)
def read_init():
"""Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
def read_init() -> Dict[str, List[str]]:
"""
Read the init and extract backend-specific objects.
Returns:
`Dict[str, List[str]]`: A dictionary mapping each backend name to the list of object names requiring that backend.
"""
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
@ -83,8 +122,10 @@ def read_init():
line = lines[line_index]
single_line_import_search = _re_single_line_import.search(line)
if single_line_import_search is not None:
# Single-line imports
objects.extend(single_line_import_search.groups()[0].split(", "))
elif line.startswith(" " * 12):
# Multiple-line imports (with 3 indent levels)
objects.append(line[12:-2])
line_index += 1
@ -95,8 +136,17 @@ def read_init():
return backend_specific_objects
def create_dummy_object(name, backend_name):
"""Create the code for the dummy object corresponding to `name`."""
def create_dummy_object(name: str, backend_name: str) -> str:
"""
Create the code for a dummy object.
Args:
name (`str`): The name of the object.
backend_name (`str`): The name of the backend required for that object.
Returns:
`str`: The code of the dummy object.
"""
if name.isupper():
return DUMMY_CONSTANT.format(name)
elif name.islower():
@ -105,11 +155,21 @@ def create_dummy_object(name, backend_name):
return DUMMY_CLASS.format(name, backend_name)
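
The dispatch on the name's casing, with hypothetical arguments:

```python
create_dummy_object("BERT_PRETRAINED_MODEL_ARCHIVE_LIST", '["torch"]')  # all-uppercase -> constant template
create_dummy_object("load_tf_weights_in_bert", '["tf"]')                # all-lowercase -> function template
create_dummy_object("BertModel", '["torch"]')                           # anything else -> class template
```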
def create_dummy_files(backend_specific_objects=None):
"""Create the content of the dummy files."""
def create_dummy_files(backend_specific_objects: Optional[Dict[str, List[str]]] = None) -> Dict[str, str]:
"""
Create the content of the dummy files.
Args:
backend_specific_objects (`Dict[str, List[str]]`, *optional*):
The mapping of backend names to lists of backend-specific objects. If not passed, it will be obtained by calling
`read_init()`.
Returns:
`Dict[str, str]`: A dictionary mapping each backend name to the code of its dummy file.
"""
if backend_specific_objects is None:
backend_specific_objects = read_init()
# For special correspondence backend to module name as used in the function requires_modulename
dummy_files = {}
for backend, objects in backend_specific_objects.items():
@ -122,10 +182,17 @@ def create_dummy_files(backend_specific_objects=None):
return dummy_files
def check_dummies(overwrite=False):
"""Check if the dummy files are up to date and maybe `overwrite` with the right content."""
def check_dummies(overwrite: bool = False):
"""
Check if the dummy files are up to date and maybe `overwrite` with the right content.
Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to overwrite the content of the dummy files. Will raise an error if they are not up to date
when `overwrite=False`.
"""
dummy_files = create_dummy_files()
# For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
# For special correspondence backend name to shortcut as used in utils/dummy_xxx_objects.py
short_names = {"torch": "pt"}
# Locate actual dummy modules and read their content.
@ -143,6 +210,7 @@ def check_dummies(overwrite=False):
else:
actual_dummies[backend] = ""
# Compare actual with what they should be.
for backend in dummy_files.keys():
if dummy_files[backend] != actual_dummies[backend]:
if overwrite:

View File

@ -12,13 +12,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that checks the custom inits of Transformers are well-defined: Transformers uses init files that delay the
import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with
delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the names of the
objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. The goal of this
script is to check the objects defined in both halves are the same.
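
A minimal (hypothetical) custom init with its two matching halves looks like:

```python
_import_structure = {"configuration_bert": ["BertConfig"]}

if TYPE_CHECKING:
    from .configuration_bert import BertConfig
```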
This also checks the main init properly references all submodules, even if it doesn't import anything from them: every
submodule should be defined as a key of `_import_structure` (potentially with an empty list as value), or the submodule
won't be importable.
Use from the root of the repo with:
```bash
python utils/check_inits.py
```
for a check that will error in case of inconsistencies (used by `make repo-consistency`).
There is no auto-fix possible here sadly :-(
"""
import collections
import os
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
# Path is set with the intent you should run this script from the root of the repo.
PATH_TO_TRANSFORMERS = "src/transformers"
@ -46,8 +70,18 @@ _re_try = re.compile(r"^\s*try:")
_re_else = re.compile(r"^\s*else:")
def find_backend(line):
"""Find one (or multiple) backend in a code line of the init."""
def find_backend(line: str) -> Optional[str]:
"""
Find one (or multiple) backends in a code line of the init.
Args:
line (`str`): A code line of the main init.
Returns:
`Optional[str]`: If one (or several) backends are found, returns them. In the case of multiple backends (the line
contains `if is_xxx_available() and is_yyy_available()`), returns all backends joined on `_and_` (so
`xxx_and_yyy` for instance).
"""
if _re_test_backend.search(line) is None:
return None
backends = [b[0] for b in _re_backend.findall(line)]
@ -55,14 +89,23 @@ def find_backend(line):
return "_and_".join(backends)
def parse_init(init_file):
def parse_init(init_file: str) -> Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]:
"""
Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
defined
Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects
defined.
Args:
init_file (`str`): Path to the init file to inspect.
Returns:
`Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]`: A tuple of two dictionaries mapping backends to lists of
imported objects, one for the `_import_structure` part of the init and one for the `TYPE_CHECKING` part of the
init. Returns `None` if the init is not a custom init.
"""
with open(init_file, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
# Get to the `_import_structure` definition.
line_index = 0
while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"):
line_index += 1
@ -91,7 +134,9 @@ def parse_init(init_file):
objects.append(line[9:-3])
line_index += 1
# Those are stored with the key "none".
import_dict_objects = {"none": objects}
# Let's continue with backend-specific objects in _import_structure
while not lines[line_index].startswith("if TYPE_CHECKING"):
# If the line is an if not is_backend_available, we grab all objects associated.
@ -151,6 +196,7 @@ def parse_init(init_file):
line_index += 1
type_hint_objects = {"none": objects}
# Let's continue with backend-specific objects
while line_index < len(lines):
# If the line is an if is_backend_available, we grab all objects associated.
@ -186,19 +232,33 @@ def parse_init(init_file):
return import_dict_objects, type_hint_objects
def analyze_results(import_dict_objects, type_hint_objects):
def analyze_results(import_dict_objects: Dict[str, List[str]], type_hint_objects: Dict[str, List[str]]) -> List[str]:
"""
Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.
Args:
import_dict_objects (`Dict[str, List[str]]`):
A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
the list of imported objects.
type_hint_objects (`Dict[str, List[str]]`):
A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
the list of imported objects.
Returns:
`List[str]`: The list of errors corresponding to mismatches.
"""
def find_duplicates(seq):
return [k for k, v in collections.Counter(seq).items() if v > 1]
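# For instance (illustrative): find_duplicates(["BertModel", "BertConfig", "BertModel"])
# returns ["BertModel"], flagging an object imported twice in the same backend block.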
# If one backend is missing from the other part of the init, error early.
if list(import_dict_objects.keys()) != list(type_hint_objects.keys()):
return ["Both sides of the init do not have the same backends!"]
errors = []
# Find all errors.
for key in import_dict_objects.keys():
# Duplicate imports in any half.
duplicate_imports = find_duplicates(import_dict_objects[key])
if duplicate_imports:
errors.append(f"Duplicate _import_structure definitions for: {duplicate_imports}")
@ -206,6 +266,7 @@ def analyze_results(import_dict_objects, type_hint_objects):
if duplicate_type_hints:
errors.append(f"Duplicate TYPE_CHECKING objects for: {duplicate_type_hints}")
# Missing imports in either part of the init.
if sorted(set(import_dict_objects[key])) != sorted(set(type_hint_objects[key])):
name = "base imports" if key == "none" else f"{key} backend"
errors.append(f"Differences for {name}:")
@ -237,7 +298,7 @@ def check_all_inits():
raise ValueError("\n\n".join(failures))
def get_transformers_submodules():
def get_transformers_submodules() -> List[str]:
"""
Returns the list of Transformers submodules.
"""
@ -272,6 +333,9 @@ IGNORE_SUBMODULES = [
def check_submodules():
"""
Check all submodules of Transformers are properly registered in the main init. Error otherwise.
"""
# This is to make sure the transformers module imported is the one in the repo.
from transformers.utils import direct_transformers_import

View File

@ -12,15 +12,34 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utility that performs several consistency checks on the repo. This includes:
- checking all models are properly defined in the __init__ of models/
- checking all models are in the main __init__
- checking all models are properly tested
- checking all objects in the main __init__ are documented
- checking all models are in at least one auto class
- checking all the auto mappings are properly defined (no typos, importable)
- checking the list of deprecated models is up to date
Use from the root of the repo with (as used in `make repo-consistency`):
```bash
python utils/check_repo.py
```
It has no auto-fix mode.
"""
import inspect
import os
import re
import sys
import types
import warnings
from collections import OrderedDict
from difflib import get_close_matches
from pathlib import Path
from typing import List, Tuple
from transformers import is_flax_available, is_tf_available, is_torch_available
from transformers.models.auto import get_values
@ -60,91 +79,25 @@ PRIVATE_MODELS = [
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested
"InstructBlipQFormerModel", # Building part of bigger (tested) model.
"NllbMoeDecoder",
"NllbMoeEncoder",
"UMT5EncoderModel", # Building part of bigger (tested) model.
"LlamaDecoder", # Building part of bigger (tested) model.
"Blip2QFormerModel", # Building part of bigger (tested) model.
"DetaEncoder", # Building part of bigger (tested) model.
"DetaDecoder", # Building part of bigger (tested) model.
"ErnieMForInformationExtraction",
"GraphormerEncoder", # Building part of bigger (tested) model.
"GraphormerDecoderHead", # Building part of bigger (tested) model.
"CLIPSegDecoder", # Building part of bigger (tested) model.
"TableTransformerEncoder", # Building part of bigger (tested) model.
"TableTransformerDecoder", # Building part of bigger (tested) model.
"TimeSeriesTransformerEncoder", # Building part of bigger (tested) model.
"TimeSeriesTransformerDecoder", # Building part of bigger (tested) model.
"InformerEncoder", # Building part of bigger (tested) model.
"InformerDecoder", # Building part of bigger (tested) model.
"AutoformerEncoder", # Building part of bigger (tested) model.
"AutoformerDecoder", # Building part of bigger (tested) model.
"JukeboxVQVAE", # Building part of bigger (tested) model.
"JukeboxPrior", # Building part of bigger (tested) model.
"DeformableDetrEncoder", # Building part of bigger (tested) model.
"DeformableDetrDecoder", # Building part of bigger (tested) model.
"OPTDecoder", # Building part of bigger (tested) model.
"FlaxWhisperDecoder", # Building part of bigger (tested) model.
"FlaxWhisperEncoder", # Building part of bigger (tested) model.
"WhisperDecoder", # Building part of bigger (tested) model.
"WhisperEncoder", # Building part of bigger (tested) model.
"DecisionTransformerGPT2Model", # Building part of bigger (tested) model.
"SegformerDecodeHead", # Building part of bigger (tested) model.
"PLBartEncoder", # Building part of bigger (tested) model.
"PLBartDecoder", # Building part of bigger (tested) model.
"PLBartDecoderWrapper", # Building part of bigger (tested) model.
"BigBirdPegasusEncoder", # Building part of bigger (tested) model.
"BigBirdPegasusDecoder", # Building part of bigger (tested) model.
"BigBirdPegasusDecoderWrapper", # Building part of bigger (tested) model.
"DetrEncoder", # Building part of bigger (tested) model.
"DetrDecoder", # Building part of bigger (tested) model.
"DetrDecoderWrapper", # Building part of bigger (tested) model.
"ConditionalDetrEncoder", # Building part of bigger (tested) model.
"ConditionalDetrDecoder", # Building part of bigger (tested) model.
"M2M100Encoder", # Building part of bigger (tested) model.
"M2M100Decoder", # Building part of bigger (tested) model.
"MCTCTEncoder", # Building part of bigger (tested) model.
"MgpstrModel", # Building part of bigger (tested) model.
"Speech2TextEncoder", # Building part of bigger (tested) model.
"Speech2TextDecoder", # Building part of bigger (tested) model.
"LEDEncoder", # Building part of bigger (tested) model.
"LEDDecoder", # Building part of bigger (tested) model.
"BartDecoderWrapper", # Building part of bigger (tested) model.
"BartEncoder", # Building part of bigger (tested) model.
"BertLMHeadModel", # Needs to be setup as decoder.
"BlenderbotSmallEncoder", # Building part of bigger (tested) model.
"BlenderbotSmallDecoderWrapper", # Building part of bigger (tested) model.
"BlenderbotEncoder", # Building part of bigger (tested) model.
"BlenderbotDecoderWrapper", # Building part of bigger (tested) model.
"MBartEncoder", # Building part of bigger (tested) model.
"MBartDecoderWrapper", # Building part of bigger (tested) model.
"MegatronBertLMHeadModel", # Building part of bigger (tested) model.
"MegatronBertEncoder", # Building part of bigger (tested) model.
"MegatronBertDecoder", # Building part of bigger (tested) model.
"MegatronBertDecoderWrapper", # Building part of bigger (tested) model.
"MusicgenDecoder", # Building part of bigger (tested) model.
"MvpDecoderWrapper", # Building part of bigger (tested) model.
"MvpEncoder", # Building part of bigger (tested) model.
"PegasusEncoder", # Building part of bigger (tested) model.
"PegasusDecoderWrapper", # Building part of bigger (tested) model.
"PegasusXEncoder", # Building part of bigger (tested) model.
"PegasusXDecoder", # Building part of bigger (tested) model.
"PegasusXDecoderWrapper", # Building part of bigger (tested) model.
"DPREncoder", # Building part of bigger (tested) model.
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
"RealmBertModel", # Building part of bigger (tested) model.
"RealmReader", # Not regular model.
"RealmScorer", # Not regular model.
"RealmForOpenQA", # Not regular model.
"ReformerForMaskedLM", # Needs to be setup as decoder.
"Speech2Text2DecoderWrapper", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (tested) model.
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
"TFRobertaForMultipleChoice", # TODO: fix
"TFRobertaPreLayerNormForMultipleChoice", # TODO: fix
"TrOCRDecoderWrapper", # Building part of bigger (tested) model.
"TFWhisperEncoder", # Building part of bigger (tested) model.
"TFWhisperDecoder", # Building part of bigger (tested) model.
"SeparableConv1D", # Building part of bigger (tested) model.
"FlaxBartForCausalLM", # Building part of bigger (tested) model.
"FlaxBertForCausalLM", # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
@ -155,18 +108,6 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
"TFBlipTextLMHeadModel", # No need to test it as it is tested by BlipTextVision models
"BridgeTowerTextModel", # No need to test it as it is tested by BridgeTowerModel model.
"BridgeTowerVisionModel", # No need to test it as it is tested by BridgeTowerModel model.
"SpeechT5Decoder", # Building part of bigger (tested) model.
"SpeechT5DecoderWithoutPrenet", # Building part of bigger (tested) model.
"SpeechT5DecoderWithSpeechPrenet", # Building part of bigger (tested) model.
"SpeechT5DecoderWithTextPrenet", # Building part of bigger (tested) model.
"SpeechT5Encoder", # Building part of bigger (tested) model.
"SpeechT5EncoderWithoutPrenet", # Building part of bigger (tested) model.
"SpeechT5EncoderWithSpeechPrenet", # Building part of bigger (tested) model.
"SpeechT5EncoderWithTextPrenet", # Building part of bigger (tested) model.
"SpeechT5SpeechDecoder", # Building part of bigger (tested) model.
"SpeechT5SpeechEncoder", # Building part of bigger (tested) model.
"SpeechT5TextDecoder", # Building part of bigger (tested) model.
"SpeechT5TextEncoder", # Building part of bigger (tested) model.
"BarkCausalModel", # Building part of bigger (tested) model.
"BarkModel", # Does not have a forward signature - generation tested with integration tests
]
@ -236,12 +177,6 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"AutoformerForPrediction",
"JukeboxVQVAE",
"JukeboxPrior",
"PegasusXEncoder",
"PegasusXDecoder",
"PegasusXDecoderWrapper",
"PegasusXEncoder",
"PegasusXDecoder",
"PegasusXDecoderWrapper",
"SamModel",
"DPTForDepthEstimation",
"DecisionTransformerGPT2Model",
@ -250,17 +185,11 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
"ViltForImageAndTextRetrieval",
"ViltForTokenClassification",
"ViltForMaskedLM",
"XGLMEncoder",
"XGLMDecoder",
"XGLMDecoderWrapper",
"PerceiverForMultimodalAutoencoding",
"PerceiverForOpticalFlow",
"SegformerDecodeHead",
"TFSegformerDecodeHead",
"FlaxBeitForMaskedImageModeling",
"PLBartEncoder",
"PLBartDecoder",
"PLBartDecoderWrapper",
"BeitForMaskedImageModeling",
"ChineseCLIPTextModel",
"ChineseCLIPVisionModel",
@ -347,7 +276,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
]
# DO NOT edit this list!
# (The corresponding pytorch objects should never be in the main `__init__`, but it's too late to remove)
# (The corresponding pytorch objects should never have been in the main `__init__`, but it's too late to remove)
OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
"FlaxBertLayer",
"FlaxBigBirdLayer",
@ -361,8 +290,7 @@ OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
"TFViTMAELayer",
]
# Update this list for models that have multiple model types for the same
# model doc
# Update this list for models that have multiple model types for the same model doc.
MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
[
("data2vec-text", "data2vec"),
@ -378,6 +306,10 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
def check_missing_backends():
"""
Checks if all backends are installed (otherwise the check of this script is incomplete). Will error in the CI if
that's not the case, but only throws a warning for users running this script.
"""
missing_backends = []
if not is_torch_available():
missing_backends.append("PyTorch")
@ -402,7 +334,9 @@ def check_missing_backends():
def check_model_list():
"""Check the model list inside the transformers library."""
"""
Checks the models listed as subfolders of `models` match the models available in `transformers.models`.
"""
# Get the models from the directory structure of `src/transformers/models/`
models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
_models = []
@ -413,7 +347,7 @@ def check_model_list():
if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir):
_models.append(model)
# Get the models from the directory structure of `src/transformers/models/`
# Get the models in the submodule `transformers.models`
models = [model for model in dir(transformers.models) if not model.startswith("__")]
missing_models = sorted(set(_models).difference(models))
@ -425,8 +359,8 @@ def check_model_list():
# If some modeling modules should be ignored for all checks, they should be added in the nested list
# _ignore_modules of this function.
def get_model_modules():
"""Get the model modules inside the transformers library."""
def get_model_modules() -> List[str]:
"""Get all the model modules inside the transformers library (except deprecated models)."""
_ignore_modules = [
"modeling_auto",
"modeling_encoder_decoder",
@ -454,21 +388,32 @@ def get_model_modules():
]
modules = []
for model in dir(transformers.models):
if model == "deprecated":
continue
# There are some magic dunder attributes in the dir, we ignore them
if not model.startswith("__"):
model_module = getattr(transformers.models, model)
for submodule in dir(model_module):
if submodule.startswith("modeling") and submodule not in _ignore_modules:
modeling_module = getattr(model_module, submodule)
if inspect.ismodule(modeling_module):
modules.append(modeling_module)
if model == "deprecated" or model.startswith("__"):
continue
model_module = getattr(transformers.models, model)
for submodule in dir(model_module):
if submodule.startswith("modeling") and submodule not in _ignore_modules:
modeling_module = getattr(model_module, submodule)
if inspect.ismodule(modeling_module):
modules.append(modeling_module)
return modules
def get_models(module, include_pretrained=False):
"""Get the objects in module that are models."""
def get_models(module: types.ModuleType, include_pretrained: bool = False) -> List[Tuple[str, type]]:
"""
Get the objects in a module that are models.
Args:
module (`types.ModuleType`):
The module from which we are extracting models.
include_pretrained (`bool`, *optional*, defaults to `False`):
Whether or not to include the `PreTrainedModel` subclasses (like `BertPreTrainedModel`).
Returns:
`List[Tuple[str, type]]`: List of models as tuples (class name, actual class).
"""
models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module):
@ -480,12 +425,10 @@ def get_models(module, include_pretrained=False):
return models
def is_a_private_model(model):
"""Returns True if the model should not be in the main init."""
if model in PRIVATE_MODELS:
return True
# Wrapper, Encoder and Decoder are all privates
def is_building_block(model: str) -> bool:
"""
Returns `True` if a model is a building block (part of a bigger model).
"""
if model.endswith("Wrapper"):
return True
if model.endswith("Encoder"):
@ -494,7 +437,13 @@ def is_a_private_model(model):
return True
if model.endswith("Prenet"):
return True
return False
def is_a_private_model(model: str) -> bool:
"""Returns `True` if the model should not be in the main init."""
if model in PRIVATE_MODELS:
return True
return is_building_block(model)
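
Illustrative behavior of the split-out helper:

```python
assert is_building_block("BartEncoder")          # ends with "Encoder"
assert is_building_block("TrOCRDecoderWrapper")  # ends with "Wrapper"
assert not is_building_block("BartModel")
assert is_a_private_model("BartEncoder")         # building blocks count as private
```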
def check_models_are_in_init():
@ -514,11 +463,14 @@ def check_models_are_in_init():
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function.
def get_model_test_files():
"""Get the model test files.
def get_model_test_files() -> List[str]:
"""
Get the model test files.
The returned files should NOT contain the `tests` (i.e. `PATH_TO_TESTS` defined in this script). They will be
considered as paths relative to `tests`. A caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files.
Returns:
`List[str]`: The list of test files. The returned files will NOT contain the `tests` (i.e. `PATH_TO_TESTS`
defined in this script). They will be considered as paths relative to `tests`. A caller has to use
`os.path.join(PATH_TO_TESTS, ...)` to access the files.
"""
_ignore_files = [
@ -531,7 +483,6 @@ def get_model_test_files():
"test_modeling_tf_encoder_decoder",
]
test_files = []
# Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models`
model_test_root = os.path.join(PATH_TO_TESTS, "models")
model_test_dirs = []
for x in os.listdir(model_test_root):
@ -553,9 +504,17 @@ def get_model_test_files():
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
# for the all_model_classes variable.
def find_tested_models(test_file):
"""Parse the content of test_file to detect what's in all_model_classes"""
# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
def find_tested_models(test_file: str) -> List[str]:
"""
Parse the content of `test_file` to detect what's in `all_model_classes`. This detects the models that inherit from
the common test class.
Args:
test_file (`str`): The path to the test file to check.
Returns:
`List[str]`: The list of models tested in that file.
"""
with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
content = f.read()
all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
@ -571,8 +530,25 @@ def find_tested_models(test_file):
return model_tested
def check_models_are_tested(module, test_file):
"""Check models defined in module are tested in test_file."""
def should_be_tested(model_name: str) -> bool:
"""
Whether or not a model should be tested.
"""
if model_name in IGNORE_NON_TESTED:
return False
return not is_building_block(model_name)
def check_models_are_tested(module: types.ModuleType, test_file: str) -> List[str]:
"""Check models defined in a module are all tested in a given file.
Args:
module (`types.ModuleType`): The module in which we get the models.
test_file (`str`): The path to the file where the module is tested.
Returns:
`List[str]`: The list of error messages corresponding to models not tested.
"""
# XxxPreTrainedModel are not tested
defined_models = get_models(module)
tested_models = find_tested_models(test_file)
@ -586,7 +562,7 @@ def check_models_are_tested(module, test_file):
]
failures = []
for model_name, _ in defined_models:
if model_name not in tested_models and model_name not in IGNORE_NON_TESTED:
if model_name not in tested_models and should_be_tested(model_name):
failures.append(
f"{model_name} is defined in {module.__name__} but is not tested in "
+ f"{os.path.join(PATH_TO_TESTS, test_file)}. Add it to the all_model_classes in that file."
@ -602,6 +578,7 @@ def check_all_models_are_tested():
test_files = get_model_test_files()
failures = []
for module in modules:
# Matches a module to its test file.
test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file]
if len(test_file) == 0:
failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.")
@ -616,7 +593,7 @@ def check_all_models_are_tested():
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
def get_all_auto_configured_models():
def get_all_auto_configured_models() -> List[str]:
"""Return the list of all models in at least one auto class."""
result = set() # To avoid duplicates we concatenate all model classes in a set.
if is_torch_available():
@ -634,8 +611,8 @@ def get_all_auto_configured_models():
return list(result)
def ignore_unautoclassed(model_name):
"""Rules to determine if `name` should be in an auto class."""
def ignore_unautoclassed(model_name: str) -> bool:
"""Rules to determine if a model should be in an auto class."""
# Special white list
if model_name in IGNORE_NON_AUTO_CONFIGURED:
return True
@ -645,8 +622,19 @@ def ignore_unautoclassed(model_name):
return False
def check_models_are_auto_configured(module, all_auto_models):
"""Check models defined in module are each in an auto class."""
def check_models_are_auto_configured(module: types.ModuleType, all_auto_models: List[str]) -> List[str]:
"""
Check models defined in a module are each in an auto class.
Args:
module (`types.ModuleType`):
The module in which we get the models.
all_auto_models (`List[str]`):
The list of all models in an auto class (as obtained with `get_all_auto_configured_models()`).
Returns:
`List[str]`: The list of error messages corresponding to models not in an auto class.
"""
defined_models = get_models(module)
failures = []
for model_name, _ in defined_models:
@ -661,6 +649,7 @@ def check_models_are_auto_configured(module, all_auto_models):
def check_all_models_are_auto_configured():
"""Check all models are each in an auto class."""
# This is where we need to check we have all backends or the check is incomplete.
check_missing_backends()
modules = get_model_modules()
all_auto_models = get_all_auto_configured_models()
@ -675,6 +664,7 @@ def check_all_models_are_auto_configured():
def check_all_auto_object_names_being_defined():
"""Check all names defined in auto (name) mappings exist in the library."""
# This is where we need to check we have all backends or the check is incomplete.
check_missing_backends()
failures = []
@ -695,7 +685,7 @@ def check_all_auto_object_names_being_defined():
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
for name, mapping in mappings_to_check.items():
for model_type, class_names in mapping.items():
for _, class_names in mapping.items():
if not isinstance(class_names, tuple):
class_names = (class_names,)
for class_name in class_names:
@ -716,6 +706,7 @@ def check_all_auto_object_names_being_defined():
def check_all_auto_mapping_names_in_config_mapping_names():
"""Check all keys defined in auto mappings (mappings of names) appear in `CONFIG_MAPPING_NAMES`."""
# This is where we need to check we have all backends or the check is incomplete.
check_missing_backends()
failures = []
@ -736,7 +727,7 @@ def check_all_auto_mapping_names_in_config_mapping_names():
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
for name, mapping in mappings_to_check.items():
for model_type, class_names in mapping.items():
for model_type in mapping:
if model_type not in CONFIG_MAPPING_NAMES:
failures.append(
f"`{model_type}` appears in the mapping `{name}` but it is not defined in the keys of "
@ -747,7 +738,8 @@ def check_all_auto_mapping_names_in_config_mapping_names():
def check_all_auto_mappings_importable():
"""Check all auto mappings could be imported."""
"""Check all auto mappings can be imported."""
# This is where we need to check we have all backends or the check is incomplete.
check_missing_backends()
failures = []
@ -761,7 +753,7 @@ def check_all_auto_mappings_importable():
mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")]
mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
for name, _ in mappings_to_check.items():
for name in mappings_to_check:
name = name.replace("_MAPPING_NAMES", "_MAPPING")
if not hasattr(transformers, name):
failures.append(f"`{name}`")
@ -770,44 +762,46 @@ def check_all_auto_mappings_importable():
def check_objects_being_equally_in_main_init():
"""Check if an object is in the main __init__ if its counterpart in PyTorch is."""
"""
Check if a (TensorFlow or Flax) object is in the main __init__ if and only if its counterpart in PyTorch is.
"""
attrs = dir(transformers)
failures = []
for attr in attrs:
obj = getattr(transformers, attr)
if hasattr(obj, "__module__"):
module_path = obj.__module__
if "models.deprecated" in module_path:
continue
module_name = module_path.split(".")[-1]
module_dir = ".".join(module_path.split(".")[:-1])
if (
module_name.startswith("modeling_")
and not module_name.startswith("modeling_tf_")
and not module_name.startswith("modeling_flax_")
):
parent_module = sys.modules[module_dir]
if not hasattr(obj, "__module__") or "models.deprecated" in obj.__module__:
continue
frameworks = []
if is_tf_available():
frameworks.append("TF")
if is_flax_available():
frameworks.append("Flax")
module_path = obj.__module__
module_name = module_path.split(".")[-1]
module_dir = ".".join(module_path.split(".")[:-1])
if (
module_name.startswith("modeling_")
and not module_name.startswith("modeling_tf_")
and not module_name.startswith("modeling_flax_")
):
parent_module = sys.modules[module_dir]
for framework in frameworks:
other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
other_module = getattr(parent_module, other_module_name)
if hasattr(other_module, f"{framework}{attr}"):
if not hasattr(transformers, f"{framework}{attr}"):
if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}{attr}")
if hasattr(other_module, f"{framework}_{attr}"):
if not hasattr(transformers, f"{framework}_{attr}"):
if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}_{attr}")
frameworks = []
if is_tf_available():
frameworks.append("TF")
if is_flax_available():
frameworks.append("Flax")
for framework in frameworks:
other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
other_module = getattr(parent_module, other_module_name)
if hasattr(other_module, f"{framework}{attr}"):
if not hasattr(transformers, f"{framework}{attr}"):
if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}{attr}")
if hasattr(other_module, f"{framework}_{attr}"):
if not hasattr(transformers, f"{framework}_{attr}"):
if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
failures.append(f"{framework}_{attr}")
if len(failures) > 0:
raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
@ -815,8 +809,16 @@ def check_objects_being_equally_in_main_init():
_re_decorator = re.compile(r"^\s*@(\S+)\s+$")
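
Assumed behavior of this pattern on lines read with their trailing newline:

```python
assert _re_decorator.search("    @slow\n").groups() == ("slow",)
assert _re_decorator.search("    def test_generation(self):\n") is None
```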
def check_decorator_order(filename):
"""Check that in the test file `filename` the slow decorator is always last."""
def check_decorator_order(filename: str) -> List[int]:
"""
Check that in a given test file, the slow decorator is always last.
Args:
filename (`str`): The path to a test file to check.
Returns:
`List[int]`: The indices of the lines where there are problems.
"""
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
decorator_before = None
@ -849,8 +851,13 @@ def check_all_decorator_order():
)
def find_all_documented_objects():
"""Parse the content of all doc files to detect which classes and functions it documents"""
def find_all_documented_objects() -> List[str]:
"""
Parse the content of all doc files to detect which classes and functions they document.
Returns:
`List[str]`: The list of all object names being documented.
"""
documented_obj = []
for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
@ -959,8 +966,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
]
def ignore_undocumented(name):
"""Rules to determine if `name` should be undocumented."""
def ignore_undocumented(name: str) -> bool:
"""Rules to determine if `name` should be undocumented (returns `True` if it should not be documented)."""
# NOT DOCUMENTED ON PURPOSE.
# Constants uppercase are not documented.
if name.isupper():
@ -1047,7 +1054,7 @@ _re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
_re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE)
def is_rst_docstring(docstring):
def is_rst_docstring(docstring: str) -> bool:
"""
Returns `True` if `docstring` is written in rst.
"""
@ -1061,7 +1068,7 @@ def is_rst_docstring(docstring):
def check_docstrings_are_in_md():
"""Check all docstrings are in md"""
"""Check all docstrings are written in md and nor rst."""
files_with_rst = []
for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
with open(file, encoding="utf-8") as f:
@ -1084,6 +1091,9 @@ def check_docstrings_are_in_md():
def check_deprecated_constant_is_up_to_date():
"""
Check if the constant `DEPRECATED_MODELS` in `models/auto/configuration_auto.py` is up to date.
"""
deprecated_folder = os.path.join(PATH_TO_TRANSFORMERS, "models", "deprecated")
deprecated_models = [m for m in os.listdir(deprecated_folder) if not m.startswith("_")]