Doc checks (#25408)

* Document check_dummies
* Type hints and doc in other files
* Document check inits
* Add documentation to
* Address review comments

This commit is contained in:
parent b14d4641f6
commit 16edf4d9fd
--- a/utils/check_copies.py
+++ b/utils/check_copies.py
@@ -40,6 +40,7 @@ import argparse
 import glob
 import os
 import re
+from typing import List, Optional, Tuple
 
 import black
 from doc_builder.style_doc import style_docstrings_in_code
@@ -125,14 +126,22 @@ LOCALIZED_READMES = {
 transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
 
 
-def _should_continue(line, indent):
+def _should_continue(line: str, indent: str) -> bool:
+    # Helper function. Returns `True` if `line` is empty, starts with the `indent` or is the end parenthesis of a
+    # function definition
    return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
 
 
-def find_code_in_transformers(object_name):
-    """Find and return the source code of `object_name`."""
+def find_code_in_transformers(object_name: str) -> str:
+    """
+    Find and return the source code of an object.
+
+    Args:
+        object_name (`str`): The name of the object we want the source code of.
+
+    Returns:
+        `str`: The source code of the object.
+    """
     parts = object_name.split(".")
     i = 0
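For context on the hunk above: `find_code_in_transformers` resolves a dotted name like `models.bert.modeling_bert.BertModel` by importing longer and longer prefixes of the path as modules, then walking the remaining attributes. A minimal standalone sketch of that resolution idea (a hypothetical helper; `inspect.getsource` stands in for the line-based extraction the real function performs):

```python
import importlib
import inspect


def get_source_sketch(object_name: str, package: str = "transformers") -> str:
    # Import the longest prefix of the dotted path that is a module...
    parts = object_name.split(".")
    module = importlib.import_module(package)
    i = 0
    while i < len(parts):
        try:
            module = importlib.import_module(f"{module.__name__}.{parts[i]}")
            i += 1
        except ImportError:
            break
    # ... then fetch the remaining attributes off that module.
    obj = module
    for part in parts[i:]:
        obj = getattr(obj, part)
    return inspect.getsource(obj)


# e.g. get_source_sketch("models.bert.modeling_bert.BertModel")
```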
@@ -181,7 +190,16 @@ _re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)")
 _re_fill_pattern = re.compile(r"<FILL\s+[^>]*>")
 
 
-def get_indent(code):
+def get_indent(code: str) -> str:
+    """
+    Find the indent in the first non empty line in a code sample.
+
+    Args:
+        code (`str`): The code to inspect.
+
+    Returns:
+        `str`: The indent looked at (as string).
+    """
     lines = code.split("\n")
     idx = 0
     while idx < len(lines) and len(lines[idx]) == 0:
@@ -191,9 +209,15 @@ def get_indent(code):
     return ""
 
 
-def blackify(code):
+def blackify(code: str) -> str:
     """
-    Applies the black part of our `make style` command to `code`.
+    Applies the black part of our `make style` command to some code.
+
+    Args:
+        code (`str`): The code to format.
+
+    Returns:
+        `str`: The formatted code.
     """
     has_indent = len(get_indent(code)) > 0
     if has_indent:
@@ -204,14 +228,22 @@ def blackify(code):
     return result[len("class Bla:\n") :] if has_indent else result
 
 
-def check_codes_match(observed_code, theoretical_code):
+def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int]:
     """
-    Checks if the code in `observed_code` and `theoretical_code` match with the exception of the class/function name.
-    Returns the index of the first line where there is a difference (if any) and `None` if the codes match.
+    Checks if two versions of a code match with the exception of the class/function name.
+
+    Args:
+        observed_code (`str`): The code found.
+        theoretical_code (`str`): The code to match.
+
+    Returns:
+        `Optional[int]`: The index of the first line where there is a difference (if any) and `None` if the codes
+        match.
     """
     observed_code_header = observed_code.split("\n")[0]
     theoretical_code_header = theoretical_code.split("\n")[0]
 
+    # Catch the function/class name: it is expected that those do not match.
     _re_class_match = re.compile(r"class\s+([^\(:]+)(?:\(|:)")
     _re_func_match = re.compile(r"def\s+([^\(]+)\(")
     for re_pattern in [_re_class_match, _re_func_match]:
@@ -220,6 +252,7 @@ def check_codes_match(observed_code, theoretical_code):
         theoretical_name = re_pattern.search(theoretical_code_header).groups()[0]
         theoretical_code_header = theoretical_code_header.replace(theoretical_name, observed_obj_name)
 
+    # Find the first diff. Line 0 is special since we need to compare with the function/class names ignored.
     diff_index = 0
     if theoretical_code_header != observed_code_header:
         return 0
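Worth pausing on `blackify` here: Black only accepts syntactically complete top-level code, so the function wraps indented snippets (e.g. a method body) in a dummy `class Bla:` before formatting and strips the wrapper line again — that is what `return result[len("class Bla:\n") :]` in the hunk above does. A minimal sketch of the trick (assuming Black's public `format_str`/`Mode` API and the repo's 119-character line length):

```python
import black


def blackify_sketch(code: str) -> str:
    # An indented snippet is not valid top-level Python, so wrap it in a
    # dummy class to make it parseable by Black.
    has_indent = len(code) > 0 and code.splitlines()[0].startswith(" ")
    if has_indent:
        code = f"class Bla:\n{code}"
    result = black.format_str(code, mode=black.Mode(line_length=119))
    # Strip the dummy "class Bla:\n" header to recover the snippet.
    return result[len("class Bla:\n") :] if has_indent else result
```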
@@ -231,11 +264,19 @@ def check_codes_match(observed_code, theoretical_code):
         diff_index += 1
 
 
-def is_copy_consistent(filename, overwrite=False):
+def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[Tuple[str, int]]]:
     """
-    Check if the code commented as a copy in `filename` matches the original.
+    Check if the code commented as a copy in a file matches the original.
 
-    Return the differences or overwrites the content depending on `overwrite`.
+    Args:
+        filename (`str`):
+            The name of the file to check.
+        overwrite (`bool`, *optional*, defaults to `False`):
+            Whether or not to overwrite the copies when they don't match.
+
+    Returns:
+        `Optional[List[Tuple[str, int]]]`: If `overwrite=False`, returns the list of differences as tuples `(str, int)`
+        with the name of the object having a diff and the line number where there is the first diff.
+    """
     with open(filename, "r", encoding="utf-8", newline="\n") as f:
         lines = f.readlines()
 
@@ -308,8 +349,12 @@ def is_copy_consistent(filename, overwrite=False):
 
 def check_copies(overwrite: bool = False):
     """
-    Check every file is copy-consistent with the original and maybe `overwrite` content when it is not. Also check the
-    model list in the main README and other READMEs/index.md are consistent.
+    Check every file is copy-consistent with the original. Also check the model list in the main README and other
+    READMEs/index.md are consistent.
+
+    Args:
+        overwrite (`bool`, *optional*, defaults to `False`):
+            Whether or not to overwrite the copies when they don't match.
     """
     all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
     diffs = []
@@ -328,8 +373,11 @@ def check_copies(overwrite: bool = False):
 
 def check_full_copies(overwrite: bool = False):
     """
-    Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent and maybe
-    `overwrite` to fix issues.
+    Check the files that are full copies of others (as indicated in `FULL_COPIES`) are copy-consistent.
+
+    Args:
+        overwrite (`bool`, *optional*, defaults to `False`):
+            Whether or not to overwrite the copies when they don't match.
     """
     diffs = []
     for target, source in FULL_COPIES.items():
@@ -354,8 +402,18 @@ def check_full_copies(overwrite: bool = False):
         )
 
 
-def get_model_list(filename, start_prompt, end_prompt):
-    """Extracts the model list from a README, between `start_prompt` and `end_prompt`."""
+def get_model_list(filename: str, start_prompt: str, end_prompt: str) -> str:
+    """
+    Extracts the model list from a README.
+
+    Args:
+        filename (`str`): The name of the README file to check.
+        start_prompt (`str`): The string to look for that introduces the model list.
+        end_prompt (`str`): The string to look for that ends the model list.
+
+    Returns:
+        `str`: The model list.
+    """
     with open(os.path.join(REPO_PATH, filename), "r", encoding="utf-8", newline="\n") as f:
         lines = f.readlines()
     # Find the start of the list.
@@ -368,6 +426,7 @@ def get_model_list(filename, start_prompt, end_prompt):
     current_line = ""
     end_index = start_index
 
+    # Keep going until the end of the list.
     while not lines[end_index].startswith(end_prompt):
         if lines[end_index].startswith("1."):
             if len(current_line) > 1:
@@ -382,7 +441,7 @@ def get_model_list(filename, start_prompt, end_prompt):
     return "".join(result)
 
 
-def convert_to_localized_md(model_list, localized_model_list, format_str):
+def convert_to_localized_md(model_list: str, localized_model_list: str, format_str: str) -> Tuple[bool, str]:
     """
     Compare the model list from the main README to the one in a localized README.
 
@@ -458,19 +517,33 @@ def convert_to_localized_md(model_list, localized_model_list, format_str):
     return readmes_match, "\n".join((x[1] for x in sorted_index)) + "\n"
 
 
-def convert_readme_to_index(model_list):
+def convert_readme_to_index(model_list: str) -> str:
     """
-    Converts the model list of the README to the index.md format.
+    Converts the model list of the README to the index.md format (adapting links to the doc to relative links).
+
+    Args:
+        model_list (`str`): The model list of the main README.
+
+    Returns:
+        `str`: The model list in the format for the index.
     """
+    # We need to replace both links to the main doc and the stable doc (the order of the next two instructions is important).
     model_list = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
     return model_list.replace("https://huggingface.co/docs/transformers/", "")
 
 
-def _find_text_in_file(filename, start_prompt, end_prompt):
+def _find_text_in_file(filename: str, start_prompt: str, end_prompt: str) -> Tuple[str, int, int, List[str]]:
     """
-    Find the text in `filename` between a line beginning with `start_prompt` and before `end_prompt`, removing empty
-    lines.
+    Find the text in a file between two prompts.
+
+    Args:
+        filename (`str`): The name of the file to look into.
+        start_prompt (`str`): The string to look for that introduces the content looked for.
+        end_prompt (`str`): The string to look for that ends the content looked for.
+
+    Returns:
+        Tuple[str, int, int, List[str]]: The content between the two prompts, the index of the start line in the
+        original file, the index of the end line in the original file and the list of lines of that file.
+    """
     with open(filename, "r", encoding="utf-8", newline="\n") as f:
         lines = f.readlines()
 
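A quick illustration of why that replacement order matters: the `main` doc URL extends the stable doc URL, so stripping the stable prefix first would leave a stray `main/` behind in every link.

```python
model_list = "[BERT](https://huggingface.co/docs/transformers/main/model_doc/bert)"

# Right order: remove the longer `main` prefix first, then the stable one.
out = model_list.replace("https://huggingface.co/docs/transformers/main/", "")
out = out.replace("https://huggingface.co/docs/transformers/", "")
assert out == "[BERT](model_doc/bert)"

# Wrong order: the stable prefix also matches inside the `main` URL.
bad = model_list.replace("https://huggingface.co/docs/transformers/", "")
assert bad == "[BERT](main/model_doc/bert)"
```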
@@ -493,9 +566,13 @@ def _find_text_in_file(filename, start_prompt, end_prompt):
     return "".join(lines[start_index:end_index]), start_index, end_index, lines
 
 
-def check_model_list_copy(overwrite=False, max_per_line=119):
+def check_model_list_copy(overwrite: bool = False):
     """
-    Check the model lists in the README are consistent with the ones in the other READMEs and also with `index.md`.
+    Check the model lists in the README are consistent with the ones in the other READMEs and also with `index.md`.
+
+    Args:
+        overwrite (`bool`, *optional*, defaults to `False`):
+            Whether or not to overwrite the copies when they don't match.
     """
     # Fix potential doc links in the README
     with open(os.path.join(REPO_PATH, "README.md"), "r", encoding="utf-8", newline="\n") as f:
@@ -526,6 +603,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
         end_prompt=LOCALIZED_READMES["README.md"]["end_prompt"],
     )
 
+    # Build the converted Markdown.
     converted_md_lists = []
     for filename, value in LOCALIZED_READMES.items():
         _start_prompt = value["start_prompt"]
@@ -537,6 +615,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
 
         converted_md_lists.append((filename, readmes_match, converted_md_list, _start_prompt, _end_prompt))
 
+    # Build the converted index and compare it.
     converted_md_list = convert_readme_to_index(md_list)
     if converted_md_list != index_list:
         if overwrite:
@@ -548,6 +627,7 @@ def check_model_list_copy(overwrite=False, max_per_line=119):
                 "`make fix-copies` to fix this."
             )
 
+    # Compare the converted Markdowns
    for converted_md_list in converted_md_lists:
         filename, readmes_match, converted_md, _start_prompt, _end_prompt = converted_md_list
 
@@ -606,10 +686,13 @@ README_TEMPLATE = (
 )
 
 
-def check_readme(overwrite=False):
+def check_readme(overwrite: bool = False):
     """
-    Check if the main README contains all the models in the library or not. If `overwrite`, will add an entry for the
-    missing models using `README_TEMPLATE`.
+    Check if the main README contains all the models in the library or not.
+
+    Args:
+        overwrite (`bool`, *optional*, defaults to `False`):
+            Whether or not to add an entry for the missing models using `README_TEMPLATE`.
     """
     info = LOCALIZED_READMES["README.md"]
     models, start_index, end_index, lines = _find_text_in_file(
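The four-part return value of `_find_text_in_file` documented above is designed for read-modify-write callers: they compare the extracted block against a freshly generated one and, on mismatch, splice the new block back between the two prompts. A hedged sketch of that calling pattern (the prompts and `build_model_list` are illustrative, not taken from the diff):

```python
content, start_index, end_index, lines = _find_text_in_file(
    filename="README.md",
    start_prompt="<!--This list is updated automatically-->",
    end_prompt="<!--End of the list-->",
)
new_block = build_model_list()  # hypothetical generator for the fresh content

if content != new_block:
    # Splice the regenerated block between the prompts and write the file back.
    with open("README.md", "w", encoding="utf-8", newline="\n") as f:
        f.writelines(lines[:start_index] + [new_block] + lines[end_index:])
```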
--- a/utils/check_doc_toc.py
+++ b/utils/check_doc_toc.py
@@ -34,6 +34,7 @@ python utils/check_doc_toc.py --fix_and_overwrite
 
 import argparse
 from collections import defaultdict
+from typing import List
 
 import yaml
 
@@ -41,7 +42,7 @@ import yaml
 PATH_TO_TOC = "docs/source/en/_toctree.yml"
 
 
-def clean_model_doc_toc(model_doc):
+def clean_model_doc_toc(model_doc: List[dict]) -> List[dict]:
     """
     Cleans a section of the table of contents of the model documentation (one specific modality) by removing
     duplicates and sorting models alphabetically.
@@ -77,7 +78,7 @@ def clean_model_doc_toc(model_doc):
     return sorted(new_doc, key=lambda s: s["title"].lower())
 
 
-def check_model_doc(overwrite=False):
+def check_model_doc(overwrite: bool = False):
     """
     Check that the content of the table of contents in `_toctree.yml` is clean (no duplicates and sorted for the model
     API doc) and potentially auto-cleans it.
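To make the `clean_model_doc_toc` contract concrete: the toc sections it cleans are lists of `{"local": ..., "title": ...}` dicts (the shape used in `_toctree.yml`), and cleaning means dropping duplicate pages and sorting by title. A self-contained sketch of that behavior (not the function from the diff):

```python
from typing import List


def clean_toc_sketch(model_doc: List[dict]) -> List[dict]:
    # Keep the first entry seen for each doc page, then sort alphabetically.
    seen, new_doc = set(), []
    for entry in model_doc:
        if entry["local"] not in seen:
            seen.add(entry["local"])
            new_doc.append(entry)
    return sorted(new_doc, key=lambda s: s["title"].lower())


toc = [
    {"local": "model_doc/bert", "title": "BERT"},
    {"local": "model_doc/albert", "title": "ALBERT"},
    {"local": "model_doc/bert", "title": "BERT"},  # duplicate, dropped
]
assert clean_toc_sketch(toc) == [
    {"local": "model_doc/albert", "title": "ALBERT"},
    {"local": "model_doc/bert", "title": "BERT"},
]
```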
--- a/utils/check_doctest_list.py
+++ b/utils/check_doctest_list.py
@@ -40,7 +40,16 @@ REPO_PATH = "."
 DOCTEST_FILE_PATHS = ["documentation_tests.txt", "slow_documentation_tests.txt"]
 
 
-def clean_doctest_list(doctest_file, overwrite=False):
+def clean_doctest_list(doctest_file: str, overwrite: bool = False):
+    """
+    Cleans the doctest in a given file.
+
+    Args:
+        doctest_file (`str`):
+            The path to the doctest file to check or clean.
+        overwrite (`bool`, *optional*, defaults to `False`):
+            Whether or not to fix problems. If `False`, will error when the file is not clean.
+    """
     non_existent_paths = []
     all_paths = []
     with open(doctest_file, "r", encoding="utf-8") as f:
--- a/utils/check_dummies.py
+++ b/utils/check_dummies.py
@@ -12,10 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+This script is responsible for making sure the dummies in utils/dummies_xxx.py are up to date with the main init.
+
+Why dummies? This is to make sure that a user can always import all objects from `transformers`, even if they don't
+have the necessary extra libs installed. Those objects will then raise helpful error messages whenever the user tries
+to access one of their methods.
+
+Usage (from the root of the repo):
+
+Check that the dummy files are up to date (used in `make repo-consistency`):
+
+```bash
+python utils/check_dummies.py
+```
+
+Update the dummy files if needed (used in `make fix-copies`):
+
+```bash
+python utils/check_dummies.py --fix_and_overwrite
+```
+"""
 import argparse
 import os
 import re
+from typing import Dict, List, Optional
 
 
 # All paths are set with the intent you should run this script from the root of the repo with the command
@@ -26,13 +47,16 @@ PATH_TO_TRANSFORMERS = "src/transformers"
 _re_backend = re.compile(r"is\_([a-z_]*)_available()")
+# Matches from xxx import bla
 _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
+# Matches if not is_xxx_available()
 _re_test_backend = re.compile(r"^\s+if\s+not\s+\(?is\_[a-z_]*\_available\(\)")
 
 
+# Template for the dummy objects.
 DUMMY_CONSTANT = """
 {0} = None
 """
 
 
 DUMMY_CLASS = """
 class {0}(metaclass=DummyObject):
     _backends = {1}
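To picture what those templates expand to: for a torch-only class, `create_dummy_object` fills `DUMMY_CLASS` into something like the snippet below in the generated dummy file (rendered for an illustrative name; `DummyObject` and `requires_backends` come from `transformers.utils`):

```python
# Roughly what DUMMY_CLASS.format("BertModel", '["torch"]') renders to:
class BertModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


# Importing BertModel then succeeds without torch installed; instantiating it
# raises a helpful error asking the user to install the missing backend.
```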
@@ -48,8 +72,18 @@ def {0}(*args, **kwargs):
 """
 
 
-def find_backend(line):
-    """Find one (or multiple) backend in a code line of the init."""
+def find_backend(line: str) -> Optional[str]:
+    """
+    Find one (or multiple) backend in a code line of the init.
+
+    Args:
+        line (`str`): A code line in an init file.
+
+    Returns:
+        Optional[`str`]: If one (or several) backend is found, returns it. In the case of multiple backends (the line
+        contains `if is_xxx_available() and is_yyy_available()`) returns all backends joined on `_and_` (so
+        `xxx_and_yyy` for instance).
+    """
     if _re_test_backend.search(line) is None:
         return None
     backends = [b[0] for b in _re_backend.findall(line)]
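Concretely, the two regexes above make `find_backend` behave as follows on init lines (assuming the semantics described in the new docstring):

```python
# A single availability check yields the backend name.
assert find_backend("    if not is_torch_available():") == "torch"

# Several checks on one line are joined on "_and_".
assert (
    find_backend("    if not (is_torch_available() and is_sentencepiece_available()):")
    == "torch_and_sentencepiece"
)

# Lines that are not availability guards yield None.
assert find_backend("from typing import Dict") is None
```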
@@ -57,8 +91,13 @@ def find_backend(line):
     return "_and_".join(backends)
 
 
-def read_init():
-    """Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects."""
+def read_init() -> Dict[str, List[str]]:
+    """
+    Read the init and extract backend-specific objects.
+
+    Returns:
+        Dict[str, List[str]]: A dictionary mapping backend name to the list of object names requiring that backend.
+    """
     with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
         lines = f.readlines()
 
@@ -83,8 +122,10 @@ def read_init():
             line = lines[line_index]
             single_line_import_search = _re_single_line_import.search(line)
             if single_line_import_search is not None:
+                # Single-line imports
                 objects.extend(single_line_import_search.groups()[0].split(", "))
             elif line.startswith(" " * 12):
+                # Multiple-line imports (with 3 indent levels)
                 objects.append(line[12:-2])
             line_index += 1
 
@@ -95,8 +136,17 @@ def read_init():
     return backend_specific_objects
 
 
-def create_dummy_object(name, backend_name):
-    """Create the code for the dummy object corresponding to `name`."""
+def create_dummy_object(name: str, backend_name: str) -> str:
+    """
+    Create the code for a dummy object.
+
+    Args:
+        name (`str`): The name of the object.
+        backend_name (`str`): The name of the backend required for that object.
+
+    Returns:
+        `str`: The code of the dummy object.
+    """
     if name.isupper():
         return DUMMY_CONSTANT.format(name)
     elif name.islower():
@@ -105,11 +155,21 @@ def create_dummy_object(name, backend_name):
         return DUMMY_CLASS.format(name, backend_name)
 
 
-def create_dummy_files(backend_specific_objects=None):
-    """Create the content of the dummy files."""
+def create_dummy_files(backend_specific_objects: Optional[Dict[str, List[str]]] = None) -> Dict[str, str]:
+    """
+    Create the content of the dummy files.
+
+    Args:
+        backend_specific_objects (`Dict[str, List[str]]`, *optional*):
+            The mapping backend name to list of backend-specific objects. If not passed, will be obtained by calling
+            `read_init()`.
+
+    Returns:
+        `Dict[str, str]`: A dictionary mapping backend name to code of the corresponding backend file.
+    """
     if backend_specific_objects is None:
         backend_specific_objects = read_init()
-    # For special correspondence backend to module name as used in the function requires_modulename
+
     dummy_files = {}
 
     for backend, objects in backend_specific_objects.items():
@@ -122,10 +182,17 @@ def create_dummy_files(backend_specific_objects=None):
     return dummy_files
 
 
-def check_dummies(overwrite=False):
-    """Check if the dummy files are up to date and maybe `overwrite` with the right content."""
+def check_dummies(overwrite: bool = False):
+    """
+    Check if the dummy files are up to date and maybe `overwrite` with the right content.
+
+    Args:
+        overwrite (`bool`, *optional*, defaults to `False`):
+            Whether or not to overwrite the content of the dummy files. Will raise an error if they are not up to date
+            when `overwrite=False`.
+    """
     dummy_files = create_dummy_files()
-    # For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
+    # For special correspondence backend name to shortcut as used in utils/dummy_xxx_objects.py
     short_names = {"torch": "pt"}
 
+    # Locate actual dummy modules and read their content.
@@ -143,6 +210,7 @@ def check_dummies(overwrite=False):
         else:
             actual_dummies[backend] = ""
 
+    # Compare actual with what they should be.
     for backend in dummy_files.keys():
         if dummy_files[backend] != actual_dummies[backend]:
             if overwrite:
--- a/utils/check_inits.py
+++ b/utils/check_inits.py
@@ -12,13 +12,37 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Utility that checks the custom inits of Transformers are well-defined: Transformers uses init files that delay the
+import of an object to when it's actually needed. This is to avoid the main init importing all models, which would
+make the line `import transformers` very slow when the user has all optional dependencies installed. The inits with
+delayed imports have two halves: one defining a dictionary `_import_structure` which maps modules to the names of the
+objects in each module, and one in `TYPE_CHECKING` which looks like a normal init for type-checkers. The goal of this
+script is to check the objects defined in both halves are the same.
+
+This also checks the main init properly references all submodules, even if it doesn't import anything from them: every
+submodule should be defined as a key of `_import_structure`, with an empty list as value potentially, or the submodule
+won't be importable.
+
+Use from the root of the repo with:
+
+```bash
+python utils/check_inits.py
+```
+
+for a check that will error in case of inconsistencies (used by `make repo-consistency`).
+
+There is no auto-fix possible here sadly :-(
+"""
+
 import collections
 import os
 import re
 from pathlib import Path
+from typing import Dict, List, Optional, Tuple
 
 
 # Path is set with the intent you should run this script from the root of the repo.
 PATH_TO_TRANSFORMERS = "src/transformers"
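For readers unfamiliar with the init layout this script validates, the two halves it cross-checks look roughly like this (a minimal, hypothetical model `__init__.py`; real ones also wire up `_LazyModule` from `transformers.utils`):

```python
from typing import TYPE_CHECKING

from transformers.utils import OptionalDependencyNotAvailable, is_torch_available

# Half 1: `_import_structure` maps each submodule to the names it exposes.
# Every submodule must appear as a key, even with an empty list as value.
_import_structure = {
    "configuration_bert": ["BertConfig"],
    "tokenization_bert": ["BertTokenizer"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_bert"] = ["BertModel"]

# Half 2: the `TYPE_CHECKING` branch, which must list exactly the same objects.
if TYPE_CHECKING:
    from .configuration_bert import BertConfig
    from .tokenization_bert import BertTokenizer

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_bert import BertModel
```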
@@ -46,8 +70,18 @@ _re_try = re.compile(r"^\s*try:")
 _re_else = re.compile(r"^\s*else:")
 
 
-def find_backend(line):
-    """Find one (or multiple) backend in a code line of the init."""
+def find_backend(line: str) -> Optional[str]:
+    """
+    Find one (or multiple) backend in a code line of the init.
+
+    Args:
+        line (`str`): A code line of the main init.
+
+    Returns:
+        Optional[`str`]: If one (or several) backend is found, returns it. In the case of multiple backends (the line
+        contains `if is_xxx_available() and is_yyy_available()`) returns all backends joined on `_and_` (so
+        `xxx_and_yyy` for instance).
+    """
     if _re_test_backend.search(line) is None:
         return None
     backends = [b[0] for b in _re_backend.findall(line)]
@@ -55,14 +89,23 @@ def find_backend(line):
     return "_and_".join(backends)
 
 
-def parse_init(init_file):
+def parse_init(init_file) -> Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]:
     """
-    Read an init_file and parse (per backend) the _import_structure objects defined and the TYPE_CHECKING objects
-    defined
+    Read an init_file and parse (per backend) the `_import_structure` objects defined and the `TYPE_CHECKING` objects
+    defined.
+
+    Args:
+        init_file (`str`): Path to the init file to inspect.
+
+    Returns:
+        `Optional[Tuple[Dict[str, List[str]], Dict[str, List[str]]]]`: A tuple of two dictionaries mapping backends
+        to lists of imported objects, one for the `_import_structure` part of the init and one for the
+        `TYPE_CHECKING` part of the init. Returns `None` if the init is not a custom init.
     """
     with open(init_file, "r", encoding="utf-8", newline="\n") as f:
         lines = f.readlines()
 
     # Get to the `_import_structure` definition.
     line_index = 0
     while line_index < len(lines) and not lines[line_index].startswith("_import_structure = {"):
         line_index += 1
@@ -91,7 +134,9 @@ def parse_init(init_file):
             objects.append(line[9:-3])
         line_index += 1
 
+    # Those are stored with the key "none".
     import_dict_objects = {"none": objects}
+
     # Let's continue with backend-specific objects in _import_structure
     while not lines[line_index].startswith("if TYPE_CHECKING"):
         # If the line is an if not is_backend_available, we grab all objects associated.
@@ -151,6 +196,7 @@ def parse_init(init_file):
         line_index += 1
 
     type_hint_objects = {"none": objects}
 
+    # Let's continue with backend-specific objects
     while line_index < len(lines):
         # If the line is an if is_backend_available, we grab all objects associated.
@@ -186,19 +232,33 @@ def parse_init(init_file):
     return import_dict_objects, type_hint_objects
 
 
-def analyze_results(import_dict_objects, type_hint_objects):
+def analyze_results(import_dict_objects: Dict[str, List[str]], type_hint_objects: Dict[str, List[str]]) -> List[str]:
+    """
+    Analyze the differences between _import_structure objects and TYPE_CHECKING objects found in an init.
+
+    Args:
+        import_dict_objects (`Dict[str, List[str]]`):
+            A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
+            lists of imported objects.
+        type_hint_objects (`Dict[str, List[str]]`):
+            A dictionary mapping backend names (`"none"` for the objects independent of any specific backend) to
+            lists of imported objects.
+
+    Returns:
+        `List[str]`: The list of errors corresponding to mismatches.
+    """
+
     def find_duplicates(seq):
         return [k for k, v in collections.Counter(seq).items() if v > 1]
 
     # If one backend is missing from the other part of the init, error early.
     if list(import_dict_objects.keys()) != list(type_hint_objects.keys()):
         return ["Both sides of the init do not have the same backends!"]
 
     errors = []
     # Find all errors.
     for key in import_dict_objects.keys():
         # Duplicate imports in any half.
         duplicate_imports = find_duplicates(import_dict_objects[key])
         if duplicate_imports:
             errors.append(f"Duplicate _import_structure definitions for: {duplicate_imports}")
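The two mappings `analyze_results` receives are plain backend-to-objects dicts, so the mismatch detection can be pictured on toy inputs (hypothetical values; the real dicts come out of `parse_init`, and the exact error strings follow the `errors.append` calls in these hunks):

```python
import_dict_objects = {
    "none": ["BertConfig", "BertTokenizer"],
    "torch": ["BertModel"],
}
type_hint_objects = {
    "none": ["BertConfig", "BertTokenizer"],
    "torch": [],  # BertModel is missing from the TYPE_CHECKING half
}

errors = analyze_results(import_dict_objects, type_hint_objects)
# -> a non-empty list starting with something like "Differences for torch backend:"
```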
@@ -206,6 +266,7 @@ def analyze_results(import_dict_objects, type_hint_objects):
         if duplicate_type_hints:
             errors.append(f"Duplicate TYPE_CHECKING objects for: {duplicate_type_hints}")
 
+        # Missing imports in either part of the init.
         if sorted(set(import_dict_objects[key])) != sorted(set(type_hint_objects[key])):
             name = "base imports" if key == "none" else f"{key} backend"
             errors.append(f"Differences for {name}:")
@@ -237,7 +298,7 @@ def check_all_inits():
         raise ValueError("\n\n".join(failures))
 
 
-def get_transformers_submodules():
+def get_transformers_submodules() -> List[str]:
     """
     Returns the list of Transformers submodules.
     """
@@ -272,6 +333,9 @@ IGNORE_SUBMODULES = [
 
 
 def check_submodules():
+    """
+    Check all submodules of Transformers are properly registered in the main init. Error otherwise.
+    """
     # This is to make sure the transformers module imported is the one in the repo.
     from transformers.utils import direct_transformers_import
 
--- a/utils/check_repo.py
+++ b/utils/check_repo.py
@@ -12,15 +12,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Utility that performs several consistency checks on the repo. This includes:
+- checking all models are properly defined in the __init__ of models/
+- checking all models are in the main __init__
+- checking all models are properly tested
+- checking all objects in the main __init__ are documented
+- checking all models are in at least one auto class
+- checking all the auto mappings are properly defined (no typos, importable)
+- checking the list of deprecated models is up to date
+
+Use from the root of the repo with (as used in `make repo-consistency`):
+
+```bash
+python utils/check_repo.py
+```
+
+It has no auto-fix mode.
+"""
 import inspect
 import os
 import re
 import sys
 import types
 import warnings
 from collections import OrderedDict
 from difflib import get_close_matches
 from pathlib import Path
+from typing import List, Tuple
 
 from transformers import is_flax_available, is_tf_available, is_torch_available
 from transformers.models.auto import get_values
@@ -60,91 +79,25 @@ PRIVATE_MODELS = [
 IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
     # models to ignore for not tested
     "InstructBlipQFormerModel",  # Building part of bigger (tested) model.
-    "NllbMoeDecoder",
-    "NllbMoeEncoder",
     "UMT5EncoderModel",  # Building part of bigger (tested) model.
-    "LlamaDecoder",  # Building part of bigger (tested) model.
     "Blip2QFormerModel",  # Building part of bigger (tested) model.
-    "DetaEncoder",  # Building part of bigger (tested) model.
-    "DetaDecoder",  # Building part of bigger (tested) model.
     "ErnieMForInformationExtraction",
-    "GraphormerEncoder",  # Building part of bigger (tested) model.
     "GraphormerDecoderHead",  # Building part of bigger (tested) model.
-    "CLIPSegDecoder",  # Building part of bigger (tested) model.
-    "TableTransformerEncoder",  # Building part of bigger (tested) model.
-    "TableTransformerDecoder",  # Building part of bigger (tested) model.
-    "TimeSeriesTransformerEncoder",  # Building part of bigger (tested) model.
-    "TimeSeriesTransformerDecoder",  # Building part of bigger (tested) model.
-    "InformerEncoder",  # Building part of bigger (tested) model.
-    "InformerDecoder",  # Building part of bigger (tested) model.
-    "AutoformerEncoder",  # Building part of bigger (tested) model.
-    "AutoformerDecoder",  # Building part of bigger (tested) model.
     "JukeboxVQVAE",  # Building part of bigger (tested) model.
     "JukeboxPrior",  # Building part of bigger (tested) model.
-    "DeformableDetrEncoder",  # Building part of bigger (tested) model.
-    "DeformableDetrDecoder",  # Building part of bigger (tested) model.
-    "OPTDecoder",  # Building part of bigger (tested) model.
-    "FlaxWhisperDecoder",  # Building part of bigger (tested) model.
-    "FlaxWhisperEncoder",  # Building part of bigger (tested) model.
-    "WhisperDecoder",  # Building part of bigger (tested) model.
-    "WhisperEncoder",  # Building part of bigger (tested) model.
     "DecisionTransformerGPT2Model",  # Building part of bigger (tested) model.
     "SegformerDecodeHead",  # Building part of bigger (tested) model.
-    "PLBartEncoder",  # Building part of bigger (tested) model.
-    "PLBartDecoder",  # Building part of bigger (tested) model.
-    "PLBartDecoderWrapper",  # Building part of bigger (tested) model.
-    "BigBirdPegasusEncoder",  # Building part of bigger (tested) model.
-    "BigBirdPegasusDecoder",  # Building part of bigger (tested) model.
-    "BigBirdPegasusDecoderWrapper",  # Building part of bigger (tested) model.
-    "DetrEncoder",  # Building part of bigger (tested) model.
-    "DetrDecoder",  # Building part of bigger (tested) model.
-    "DetrDecoderWrapper",  # Building part of bigger (tested) model.
-    "ConditionalDetrEncoder",  # Building part of bigger (tested) model.
-    "ConditionalDetrDecoder",  # Building part of bigger (tested) model.
-    "M2M100Encoder",  # Building part of bigger (tested) model.
-    "M2M100Decoder",  # Building part of bigger (tested) model.
-    "MCTCTEncoder",  # Building part of bigger (tested) model.
     "MgpstrModel",  # Building part of bigger (tested) model.
-    "Speech2TextEncoder",  # Building part of bigger (tested) model.
-    "Speech2TextDecoder",  # Building part of bigger (tested) model.
-    "LEDEncoder",  # Building part of bigger (tested) model.
-    "LEDDecoder",  # Building part of bigger (tested) model.
-    "BartDecoderWrapper",  # Building part of bigger (tested) model.
-    "BartEncoder",  # Building part of bigger (tested) model.
     "BertLMHeadModel",  # Needs to be setup as decoder.
-    "BlenderbotSmallEncoder",  # Building part of bigger (tested) model.
-    "BlenderbotSmallDecoderWrapper",  # Building part of bigger (tested) model.
-    "BlenderbotEncoder",  # Building part of bigger (tested) model.
-    "BlenderbotDecoderWrapper",  # Building part of bigger (tested) model.
-    "MBartEncoder",  # Building part of bigger (tested) model.
-    "MBartDecoderWrapper",  # Building part of bigger (tested) model.
     "MegatronBertLMHeadModel",  # Building part of bigger (tested) model.
-    "MegatronBertEncoder",  # Building part of bigger (tested) model.
-    "MegatronBertDecoder",  # Building part of bigger (tested) model.
-    "MegatronBertDecoderWrapper",  # Building part of bigger (tested) model.
-    "MusicgenDecoder",  # Building part of bigger (tested) model.
-    "MvpDecoderWrapper",  # Building part of bigger (tested) model.
-    "MvpEncoder",  # Building part of bigger (tested) model.
-    "PegasusEncoder",  # Building part of bigger (tested) model.
-    "PegasusDecoderWrapper",  # Building part of bigger (tested) model.
-    "PegasusXEncoder",  # Building part of bigger (tested) model.
-    "PegasusXDecoder",  # Building part of bigger (tested) model.
-    "PegasusXDecoderWrapper",  # Building part of bigger (tested) model.
-    "DPREncoder",  # Building part of bigger (tested) model.
-    "ProphetNetDecoderWrapper",  # Building part of bigger (tested) model.
     "RealmBertModel",  # Building part of bigger (tested) model.
     "RealmReader",  # Not regular model.
     "RealmScorer",  # Not regular model.
     "RealmForOpenQA",  # Not regular model.
     "ReformerForMaskedLM",  # Needs to be setup as decoder.
-    "Speech2Text2DecoderWrapper",  # Building part of bigger (tested) model.
-    "TFDPREncoder",  # Building part of bigger (tested) model.
     "TFElectraMainLayer",  # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
     "TFRobertaForMultipleChoice",  # TODO: fix
     "TFRobertaPreLayerNormForMultipleChoice",  # TODO: fix
-    "TrOCRDecoderWrapper",  # Building part of bigger (tested) model.
-    "TFWhisperEncoder",  # Building part of bigger (tested) model.
-    "TFWhisperDecoder",  # Building part of bigger (tested) model.
     "SeparableConv1D",  # Building part of bigger (tested) model.
     "FlaxBartForCausalLM",  # Building part of bigger (tested) model.
     "FlaxBertForCausalLM",  # Building part of bigger (tested) model. Tested implicitly through FlaxRobertaForCausalLM.
@@ -155,18 +108,6 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
     "TFBlipTextLMHeadModel",  # No need to test it as it is tested by BlipTextVision models
     "BridgeTowerTextModel",  # No need to test it as it is tested by BridgeTowerModel model.
     "BridgeTowerVisionModel",  # No need to test it as it is tested by BridgeTowerModel model.
-    "SpeechT5Decoder",  # Building part of bigger (tested) model.
-    "SpeechT5DecoderWithoutPrenet",  # Building part of bigger (tested) model.
-    "SpeechT5DecoderWithSpeechPrenet",  # Building part of bigger (tested) model.
-    "SpeechT5DecoderWithTextPrenet",  # Building part of bigger (tested) model.
-    "SpeechT5Encoder",  # Building part of bigger (tested) model.
-    "SpeechT5EncoderWithoutPrenet",  # Building part of bigger (tested) model.
-    "SpeechT5EncoderWithSpeechPrenet",  # Building part of bigger (tested) model.
-    "SpeechT5EncoderWithTextPrenet",  # Building part of bigger (tested) model.
-    "SpeechT5SpeechDecoder",  # Building part of bigger (tested) model.
-    "SpeechT5SpeechEncoder",  # Building part of bigger (tested) model.
-    "SpeechT5TextDecoder",  # Building part of bigger (tested) model.
-    "SpeechT5TextEncoder",  # Building part of bigger (tested) model.
     "BarkCausalModel",  # Building part of bigger (tested) model.
     "BarkModel",  # Does not have a forward signature - generation tested with integration tests
 ]
@@ -236,12 +177,6 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "AutoformerForPrediction",
     "JukeboxVQVAE",
     "JukeboxPrior",
-    "PegasusXEncoder",
-    "PegasusXDecoder",
-    "PegasusXDecoderWrapper",
-    "PegasusXEncoder",
-    "PegasusXDecoder",
-    "PegasusXDecoderWrapper",
     "SamModel",
     "DPTForDepthEstimation",
     "DecisionTransformerGPT2Model",
@@ -250,17 +185,11 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "ViltForImageAndTextRetrieval",
     "ViltForTokenClassification",
     "ViltForMaskedLM",
-    "XGLMEncoder",
-    "XGLMDecoder",
-    "XGLMDecoderWrapper",
     "PerceiverForMultimodalAutoencoding",
     "PerceiverForOpticalFlow",
     "SegformerDecodeHead",
     "TFSegformerDecodeHead",
     "FlaxBeitForMaskedImageModeling",
-    "PLBartEncoder",
-    "PLBartDecoder",
-    "PLBartDecoderWrapper",
     "BeitForMaskedImageModeling",
     "ChineseCLIPTextModel",
     "ChineseCLIPVisionModel",
@@ -347,7 +276,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
 ]
 
 # DO NOT edit this list!
-# (The corresponding pytorch objects should never be in the main `__init__`, but it's too late to remove)
+# (The corresponding pytorch objects should never have been in the main `__init__`, but it's too late to remove)
 OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
     "FlaxBertLayer",
     "FlaxBigBirdLayer",
@@ -361,8 +290,7 @@ OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK = [
     "TFViTMAELayer",
 ]
 
-# Update this list for models that have multiple model types for the same
-# model doc
+# Update this list for models that have multiple model types for the same model doc.
 MODEL_TYPE_TO_DOC_MAPPING = OrderedDict(
     [
         ("data2vec-text", "data2vec"),
@@ -378,6 +306,10 @@ transformers = direct_transformers_import(PATH_TO_TRANSFORMERS)
 
 
 def check_missing_backends():
+    """
+    Checks if all backends are installed (otherwise the check of this script is incomplete). Will error in the CI if
+    that's not the case but only throw a warning for users running this.
+    """
     missing_backends = []
     if not is_torch_available():
         missing_backends.append("PyTorch")
@@ -402,7 +334,9 @@ def check_missing_backends():
 
 
 def check_model_list():
-    """Check the model list inside the transformers library."""
+    """
+    Checks the models listed as subfolders of `models` match the models available in `transformers.models`.
+    """
     # Get the models from the directory structure of `src/transformers/models/`
     models_dir = os.path.join(PATH_TO_TRANSFORMERS, "models")
     _models = []
@@ -413,7 +347,7 @@ def check_model_list():
         if os.path.isdir(model_dir) and "__init__.py" in os.listdir(model_dir):
             _models.append(model)
 
-    # Get the models from the directory structure of `src/transformers/models/`
+    # Get the models in the submodule `transformers.models`
     models = [model for model in dir(transformers.models) if not model.startswith("__")]
 
     missing_models = sorted(set(_models).difference(models))
@@ -425,8 +359,8 @@ def check_model_list():
 
 # If some modeling modules should be ignored for all checks, they should be added in the nested list
 # _ignore_modules of this function.
-def get_model_modules():
-    """Get the model modules inside the transformers library."""
+def get_model_modules() -> List[str]:
+    """Get all the model modules inside the transformers library (except deprecated models)."""
     _ignore_modules = [
         "modeling_auto",
         "modeling_encoder_decoder",
@@ -454,21 +388,32 @@ def get_model_modules():
     ]
     modules = []
     for model in dir(transformers.models):
-        if model == "deprecated":
-            continue
-        # There are some magic dunder attributes in the dir, we ignore them
-        if not model.startswith("__"):
-            model_module = getattr(transformers.models, model)
-            for submodule in dir(model_module):
-                if submodule.startswith("modeling") and submodule not in _ignore_modules:
-                    modeling_module = getattr(model_module, submodule)
-                    if inspect.ismodule(modeling_module):
-                        modules.append(modeling_module)
+        if model == "deprecated" or model.startswith("__"):
+            continue
+
+        model_module = getattr(transformers.models, model)
+        for submodule in dir(model_module):
+            if submodule.startswith("modeling") and submodule not in _ignore_modules:
+                modeling_module = getattr(model_module, submodule)
+                if inspect.ismodule(modeling_module):
+                    modules.append(modeling_module)
     return modules
 
 
-def get_models(module, include_pretrained=False):
-    """Get the objects in module that are models."""
+def get_models(module: types.ModuleType, include_pretrained: bool = False) -> List[Tuple[str, type]]:
+    """
+    Get the objects in a module that are models.
+
+    Args:
+        module (`types.ModuleType`):
+            The module from which we are extracting models.
+        include_pretrained (`bool`, *optional*, defaults to `False`):
+            Whether or not to include the `PreTrainedModel` subclass (like `BertPreTrainedModel`).
+
+    Returns:
+        List[Tuple[str, type]]: List of models as tuples (class name, actual class).
+    """
     models = []
     model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
     for attr_name in dir(module):
@@ -480,12 +425,10 @@ def get_models(module, include_pretrained=False):
     return models
 
 
-def is_a_private_model(model):
-    """Returns True if the model should not be in the main init."""
-    if model in PRIVATE_MODELS:
-        return True
-
-    # Wrapper, Encoder and Decoder are all privates
+def is_building_block(model: str) -> bool:
+    """
+    Returns `True` if a model is a building block part of a bigger model.
+    """
     if model.endswith("Wrapper"):
         return True
     if model.endswith("Encoder"):
@@ -494,7 +437,13 @@ def is_a_private_model(model):
         return True
     if model.endswith("Prenet"):
         return True
     return False
 
 
+def is_a_private_model(model: str) -> bool:
+    """Returns `True` if the model should not be in the main init."""
+    if model in PRIVATE_MODELS:
+        return True
+    return is_building_block(model)
+
+
 def check_models_are_in_init():
@@ -514,11 +463,14 @@ def check_models_are_in_init():
 
 # If some test_modeling files should be ignored when checking models are all tested, they should be added in the
 # nested list _ignore_files of this function.
-def get_model_test_files():
-    """Get the model test files.
+def get_model_test_files() -> List[str]:
+    """
+    Get the model test files.
 
-    The returned files should NOT contain the `tests` (i.e. `PATH_TO_TESTS` defined in this script). They will be
-    considered as paths relative to `tests`. A caller has to use `os.path.join(PATH_TO_TESTS, ...)` to access the files.
+    Returns:
+        `List[str]`: The list of test files. The returned files will NOT contain the `tests` (i.e. `PATH_TO_TESTS`
+        defined in this script). They will be considered as paths relative to `tests`. A caller has to use
+        `os.path.join(PATH_TO_TESTS, ...)` to access the files.
     """
 
     _ignore_files = [
@@ -531,7 +483,6 @@ def get_model_test_files():
         "test_modeling_tf_encoder_decoder",
     ]
     test_files = []
-    # Check both `PATH_TO_TESTS` and `PATH_TO_TESTS/models`
     model_test_root = os.path.join(PATH_TO_TESTS, "models")
     model_test_dirs = []
     for x in os.listdir(model_test_root):
@@ -553,9 +504,17 @@ def get_model_test_files():
 
-# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the tester class
-# for the all_model_classes variable.
-def find_tested_models(test_file):
-    """Parse the content of test_file to detect what's in all_model_classes"""
+# This is a bit hacky but I didn't find a way to import the test_file as a module and read inside the class
+def find_tested_models(test_file: str) -> List[str]:
+    """
+    Parse the content of test_file to detect what's in `all_model_classes`. This detects the models that inherit from
+    the common test class.
+
+    Args:
+        test_file (`str`): The path to the test file to check.
+
+    Returns:
+        `List[str]`: The list of models tested in that file.
+    """
     with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
         content = f.read()
     all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
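The `all_model_classes` regex in the context line above is easiest to understand on a sample tester file; the capture group grabs everything inside the inner tuple of class names:

```python
import re

content = '''
class BertModelTest(ModelTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            BertModel,
            BertForMaskedLM,
        )
        if is_torch_available()
        else ()
    )
'''

all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
names = [name.strip() for name in all_models[0].split(",") if name.strip()]
assert names == ["BertModel", "BertForMaskedLM"]
```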
@@ -571,8 +530,25 @@ def find_tested_models(test_file):
     return model_tested
 
 
-def check_models_are_tested(module, test_file):
-    """Check models defined in module are tested in test_file."""
+def should_be_tested(model_name: str) -> bool:
+    """
+    Whether or not a model should be tested.
+    """
+    if model_name in IGNORE_NON_TESTED:
+        return False
+    return not is_building_block(model_name)
+
+
+def check_models_are_tested(module: types.ModuleType, test_file: str) -> List[str]:
+    """Check models defined in a module are all tested in a given file.
+
+    Args:
+        module (`types.ModuleType`): The module in which we get the models.
+        test_file (`str`): The path to the file where the module is tested.
+
+    Returns:
+        `List[str]`: The list of error messages corresponding to models not tested.
+    """
     # XxxPreTrainedModel are not tested
     defined_models = get_models(module)
     tested_models = find_tested_models(test_file)
@@ -586,7 +562,7 @@ def check_models_are_tested(module, test_file):
     ]
     failures = []
     for model_name, _ in defined_models:
-        if model_name not in tested_models and model_name not in IGNORE_NON_TESTED:
+        if model_name not in tested_models and should_be_tested(model_name):
             failures.append(
                 f"{model_name} is defined in {module.__name__} but is not tested in "
                 + f"{os.path.join(PATH_TO_TESTS, test_file)}. Add it to the all_model_classes in that file."
@@ -602,6 +578,7 @@ def check_all_models_are_tested():
     test_files = get_model_test_files()
     failures = []
     for module in modules:
+        # Matches a module to its test file.
         test_file = [file for file in test_files if f"test_{module.__name__.split('.')[-1]}.py" in file]
         if len(test_file) == 0:
             failures.append(f"{module.__name__} does not have its corresponding test file {test_file}.")
@@ -616,7 +593,7 @@ def check_all_models_are_tested():
         raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
 
 
-def get_all_auto_configured_models():
+def get_all_auto_configured_models() -> List[str]:
     """Return the list of all models in at least one auto class."""
     result = set()  # To avoid duplicates we concatenate all model classes in a set.
     if is_torch_available():
@@ -634,8 +611,8 @@ def get_all_auto_configured_models():
     return list(result)
 
 
-def ignore_unautoclassed(model_name):
-    """Rules to determine if `name` should be in an auto class."""
+def ignore_unautoclassed(model_name: str) -> bool:
+    """Rules to determine if a model should be in an auto class."""
     # Special white list
     if model_name in IGNORE_NON_AUTO_CONFIGURED:
         return True
@@ -645,8 +622,19 @@ def ignore_unautoclassed(model_name):
     return False
 
 
-def check_models_are_auto_configured(module, all_auto_models):
-    """Check models defined in module are each in an auto class."""
+def check_models_are_auto_configured(module: types.ModuleType, all_auto_models: List[str]) -> List[str]:
+    """
+    Check models defined in module are each in an auto class.
+
+    Args:
+        module (`types.ModuleType`):
+            The module in which we get the models.
+        all_auto_models (`List[str]`):
+            The list of all models in an auto class (as obtained with `get_all_auto_configured_models()`).
+
+    Returns:
+        `List[str]`: The list of error messages corresponding to models not tested.
+    """
     defined_models = get_models(module)
     failures = []
     for model_name, _ in defined_models:
@@ -661,6 +649,7 @@ def check_models_are_auto_configured(module, all_auto_models):
 
 def check_all_models_are_auto_configured():
     """Check all models are each in an auto class."""
+    # This is where we need to check we have all backends or the check is incomplete.
     check_missing_backends()
     modules = get_model_modules()
     all_auto_models = get_all_auto_configured_models()
@@ -675,6 +664,7 @@ def check_all_models_are_auto_configured():
 
 def check_all_auto_object_names_being_defined():
     """Check all names defined in auto (name) mappings exist in the library."""
+    # This is where we need to check we have all backends or the check is incomplete.
    check_missing_backends()
 
     failures = []
@@ -695,7 +685,7 @@ def check_all_auto_object_names_being_defined():
             mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
 
     for name, mapping in mappings_to_check.items():
-        for model_type, class_names in mapping.items():
+        for _, class_names in mapping.items():
             if not isinstance(class_names, tuple):
                 class_names = (class_names,)
             for class_name in class_names:
@@ -716,6 +706,7 @@ def check_all_auto_object_names_being_defined():
 
 def check_all_auto_mapping_names_in_config_mapping_names():
     """Check all keys defined in auto mappings (mappings of names) appear in `CONFIG_MAPPING_NAMES`."""
+    # This is where we need to check we have all backends or the check is incomplete.
     check_missing_backends()
 
     failures = []
@@ -736,7 +727,7 @@ def check_all_auto_mapping_names_in_config_mapping_names():
             mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
 
     for name, mapping in mappings_to_check.items():
-        for model_type, class_names in mapping.items():
+        for model_type in mapping:
             if model_type not in CONFIG_MAPPING_NAMES:
                 failures.append(
                     f"`{model_type}` appears in the mapping `{name}` but it is not defined in the keys of "
@@ -747,7 +738,8 @@ def check_all_auto_mapping_names_in_config_mapping_names():
 
 
 def check_all_auto_mappings_importable():
-    """Check all auto mappings could be imported."""
+    """Check all auto mappings can be imported."""
+    # This is where we need to check we have all backends or the check is incomplete.
     check_missing_backends()
 
     failures = []
@@ -761,7 +753,7 @@ def check_all_auto_mappings_importable():
         mapping_names = [x for x in dir(module) if x.endswith("_MAPPING_NAMES")]
         mappings_to_check.update({name: getattr(module, name) for name in mapping_names})
 
-    for name, _ in mappings_to_check.items():
+    for name in mappings_to_check:
         name = name.replace("_MAPPING_NAMES", "_MAPPING")
         if not hasattr(transformers, name):
             failures.append(f"`{name}`")
@@ -770,44 +762,46 @@ def check_all_auto_mappings_importable():
 
 
 def check_objects_being_equally_in_main_init():
-    """Check if an object is in the main __init__ if its counterpart in PyTorch is."""
+    """
+    Check if a (TensorFlow or Flax) object is in the main __init__ iff its counterpart in PyTorch is.
+    """
     attrs = dir(transformers)
 
     failures = []
     for attr in attrs:
         obj = getattr(transformers, attr)
-        if hasattr(obj, "__module__"):
-            module_path = obj.__module__
-            if "models.deprecated" in module_path:
-                continue
-            module_name = module_path.split(".")[-1]
-            module_dir = ".".join(module_path.split(".")[:-1])
-            if (
-                module_name.startswith("modeling_")
-                and not module_name.startswith("modeling_tf_")
-                and not module_name.startswith("modeling_flax_")
-            ):
-                parent_module = sys.modules[module_dir]
-
-                frameworks = []
-                if is_tf_available():
-                    frameworks.append("TF")
-                if is_flax_available():
-                    frameworks.append("Flax")
-
-                for framework in frameworks:
-                    other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
-                    if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
-                        other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
-                        other_module = getattr(parent_module, other_module_name)
-                        if hasattr(other_module, f"{framework}{attr}"):
-                            if not hasattr(transformers, f"{framework}{attr}"):
-                                if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
-                                    failures.append(f"{framework}{attr}")
-                        if hasattr(other_module, f"{framework}_{attr}"):
-                            if not hasattr(transformers, f"{framework}_{attr}"):
-                                if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
-                                    failures.append(f"{framework}_{attr}")
+        if not hasattr(obj, "__module__") or "models.deprecated" in obj.__module__:
+            continue
+
+        module_path = obj.__module__
+        module_name = module_path.split(".")[-1]
+        module_dir = ".".join(module_path.split(".")[:-1])
+        if (
+            module_name.startswith("modeling_")
+            and not module_name.startswith("modeling_tf_")
+            and not module_name.startswith("modeling_flax_")
+        ):
+            parent_module = sys.modules[module_dir]
+
+            frameworks = []
+            if is_tf_available():
+                frameworks.append("TF")
+            if is_flax_available():
+                frameworks.append("Flax")
+
+            for framework in frameworks:
+                other_module_path = module_path.replace("modeling_", f"modeling_{framework.lower()}_")
+                if os.path.isfile("src/" + other_module_path.replace(".", "/") + ".py"):
+                    other_module_name = module_name.replace("modeling_", f"modeling_{framework.lower()}_")
+                    other_module = getattr(parent_module, other_module_name)
+                    if hasattr(other_module, f"{framework}{attr}"):
+                        if not hasattr(transformers, f"{framework}{attr}"):
+                            if f"{framework}{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
+                                failures.append(f"{framework}{attr}")
+                    if hasattr(other_module, f"{framework}_{attr}"):
+                        if not hasattr(transformers, f"{framework}_{attr}"):
+                            if f"{framework}_{attr}" not in OBJECT_TO_SKIP_IN_MAIN_INIT_CHECK:
+                                failures.append(f"{framework}_{attr}")
     if len(failures) > 0:
         raise Exception(f"There were {len(failures)} failures:\n" + "\n".join(failures))
@@ -815,8 +809,16 @@ def check_objects_being_equally_in_main_init():
 _re_decorator = re.compile(r"^\s*@(\S+)\s+$")
 
 
-def check_decorator_order(filename):
-    """Check that in the test file `filename` the slow decorator is always last."""
+def check_decorator_order(filename: str) -> List[int]:
+    """
+    Check that in a given test file, the slow decorator is always last.
+
+    Args:
+        filename (`str`): The path to a test file to check.
+
+    Returns:
+        `List[int]`: The list of failures as a list of indices where there are problems.
+    """
     with open(filename, "r", encoding="utf-8", newline="\n") as f:
         lines = f.readlines()
     decorator_before = None
@@ -849,8 +851,13 @@ def check_all_decorator_order():
         )
 
 
-def find_all_documented_objects():
-    """Parse the content of all doc files to detect which classes and functions it documents"""
+def find_all_documented_objects() -> List[str]:
+    """
+    Parse the content of all doc files to detect which classes and functions it documents.
+
+    Returns:
+        `List[str]`: The list of all object names being documented.
+    """
     documented_obj = []
     for doc_file in Path(PATH_TO_DOC).glob("**/*.rst"):
         with open(doc_file, "r", encoding="utf-8", newline="\n") as f:
@@ -959,8 +966,8 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
 ]
 
 
-def ignore_undocumented(name):
-    """Rules to determine if `name` should be undocumented."""
+def ignore_undocumented(name: str) -> bool:
+    """Rules to determine if `name` should be undocumented (returns `True` if it should not be documented)."""
     # NOT DOCUMENTED ON PURPOSE.
     # Constants uppercase are not documented.
     if name.isupper():
@@ -1047,7 +1054,7 @@ _re_double_backquotes = re.compile(r"(^|[^`])``([^`]+)``([^`]|$)")
 _re_rst_example = re.compile(r"^\s*Example.*::\s*$", flags=re.MULTILINE)
 
 
-def is_rst_docstring(docstring):
+def is_rst_docstring(docstring: str) -> True:
     """
     Returns `True` if `docstring` is written in rst.
     """
@@ -1061,7 +1068,7 @@ def is_rst_docstring(docstring):
 
 
 def check_docstrings_are_in_md():
-    """Check all docstrings are in md"""
+    """Check all docstrings are written in md and not rst."""
     files_with_rst = []
     for file in Path(PATH_TO_TRANSFORMERS).glob("**/*.py"):
         with open(file, encoding="utf-8") as f:
@@ -1084,6 +1091,9 @@ def check_docstrings_are_in_md():
 
 
 def check_deprecated_constant_is_up_to_date():
+    """
+    Check if the constant `DEPRECATED_MODELS` in `models/auto/configuration_auto.py` is up to date.
+    """
     deprecated_folder = os.path.join(PATH_TO_TRANSFORMERS, "models", "deprecated")
     deprecated_models = [m for m in os.listdir(deprecated_folder) if not m.startswith("_")]