Allow `# Ignore copy` (#27328)

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-12-07 10:00:08 +01:00 committed by GitHub
parent 44b5506d29
commit 52746922b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 618 additions and 53 deletions

View File

@ -28,7 +28,9 @@ from ...test_tokenization_common import TokenizerTesterMixin
@require_tokenizers
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest with roberta-base->allenai/longformer-base-4096,Roberta->Longformer,roberta->longformer,
class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# Ignore copy
tokenizer_class = LongformerTokenizer
test_slow_tokenizer = True
rust_tokenizer_class = LongformerTokenizerFast
@ -71,23 +73,19 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.merges_file, "w", encoding="utf-8") as fp:
fp.write("\n".join(merges))
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_tokenizer
def get_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_rust_tokenizer
def get_rust_tokenizer(self, **kwargs):
kwargs.update(self.special_tokens_map)
return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_input_output_texts
def get_input_output_texts(self, tokenizer):
input_text = "lower newer"
output_text = "lower newer"
return input_text, output_text
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_full_tokenizer
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer"
@ -99,7 +97,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.roberta_dict_integration_testing with roberta->longformer
def longformer_dict_integration_testing(self):
tokenizer = self.get_tokenizer()
@ -110,7 +107,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
@slow
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_sequence_builders with roberta-base->allenai/longformer-base-4096
def test_sequence_builders(self):
tokenizer = self.tokenizer_class.from_pretrained("allenai/longformer-base-4096")
@ -130,7 +126,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
assert encoded_sentence == encoded_text_from_decode
assert encoded_pair == encoded_pair_from_decode
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_space_encoding
def test_space_encoding(self):
tokenizer = self.get_tokenizer()
@ -171,11 +166,9 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
self.assertNotEqual(first_char, space_encoding)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_pretokenized_inputs
def test_pretokenized_inputs(self):
pass
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_embeded_special_tokens
def test_embeded_special_tokens(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
@ -208,7 +201,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokens_r_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_change_add_prefix_space_and_trim_offsets_args
def test_change_add_prefix_space_and_trim_offsets_args(self):
for trim_offsets, add_prefix_space in itertools.product([True, False], repeat=2):
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
@ -223,7 +215,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(post_processor_state["add_prefix_space"], add_prefix_space)
self.assertEqual(post_processor_state["trim_offsets"], trim_offsets)
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments
def test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments(self):
# Test which aims to verify that the offsets are well adapted to the argument `add_prefix_space` and
# `trim_offsets`

View File

@ -95,13 +95,156 @@ class BertCopyModel(BertCopyPreTrainedModel):
"""
MOCK_DUMMY_BERT_CODE_MATCH = """
class BertDummyModel:
attr_1 = 1
attr_2 = 2
def __init__(self, a=1, b=2):
self.a = a
self.b = b
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
def forward(self, c):
return 1
def existing_common(self, c):
return 4
def existing_diff_to_be_ignored(self, c):
return 9
"""
MOCK_DUMMY_ROBERTA_CODE_MATCH = """
# Copied from transformers.models.dummy_bert_match.modeling_dummy_bert_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:
attr_1 = 1
attr_2 = 2
def __init__(self, a=1, b=2):
self.a = a
self.b = b
# Ignore copy
def only_in_roberta_to_be_ignored(self, c):
return 3
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
def forward(self, c):
return 1
def existing_common(self, c):
return 4
# Ignore copy
def existing_diff_to_be_ignored(self, c):
return 6
"""
MOCK_DUMMY_BERT_CODE_NO_MATCH = """
class BertDummyModel:
attr_1 = 1
attr_2 = 2
def __init__(self, a=1, b=2):
self.a = a
self.b = b
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
def forward(self, c):
return 1
def only_in_bert(self, c):
return 7
def existing_common(self, c):
return 4
def existing_diff_not_ignored(self, c):
return 8
def existing_diff_to_be_ignored(self, c):
return 9
"""
MOCK_DUMMY_ROBERTA_CODE_NO_MATCH = """
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:
attr_1 = 1
attr_2 = 3
def __init__(self, a=1, b=2):
self.a = a
self.b = b
# Ignore copy
def only_in_roberta_to_be_ignored(self, c):
return 3
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
def forward(self, c):
return 1
def only_in_roberta_not_ignored(self, c):
return 2
def existing_common(self, c):
return 4
def existing_diff_not_ignored(self, c):
return 5
# Ignore copy
def existing_diff_to_be_ignored(self, c):
return 6
"""
EXPECTED_REPLACED_CODE = """
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
class RobertaBertDummyModel:
attr_1 = 1
attr_2 = 2
def __init__(self, a=1, b=2):
self.a = a
self.b = b
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
def forward(self, c):
return 1
def only_in_bert(self, c):
return 7
def existing_common(self, c):
return 4
def existing_diff_not_ignored(self, c):
return 8
# Ignore copy
def existing_diff_to_be_ignored(self, c):
return 6
# Ignore copy
def only_in_roberta_to_be_ignored(self, c):
return 3
"""
def replace_in_file(filename, old, new):
    """
    Replace every occurrence of `old` by `new` in a text file, rewriting the file in place.

    Args:
        filename (`str` or `os.PathLike`):
            The path of the file to modify. It is read and written as UTF-8.
        old (`str`):
            The substring to look for.
        new (`str`):
            The replacement text.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    content = content.replace(old, new)

    # Open the file for writing exactly once. `newline="\n"` turns off newline translation on write, so the
    # content is written with "\n" line endings whatever the platform is.
    with open(filename, "w", encoding="utf-8", newline="\n") as f:
        f.write(content)
@ -117,11 +260,18 @@ def create_tmp_repo(tmp_dir):
model_dir = tmp_dir / "src" / "transformers" / "models"
model_dir.mkdir(parents=True, exist_ok=True)
models = {"bert": MOCK_BERT_CODE, "bertcopy": MOCK_BERT_COPY_CODE}
models = {
"bert": MOCK_BERT_CODE,
"bertcopy": MOCK_BERT_COPY_CODE,
"dummy_bert_match": MOCK_DUMMY_BERT_CODE_MATCH,
"dummy_roberta_match": MOCK_DUMMY_ROBERTA_CODE_MATCH,
"dummy_bert_no_match": MOCK_DUMMY_BERT_CODE_NO_MATCH,
"dummy_roberta_no_match": MOCK_DUMMY_ROBERTA_CODE_NO_MATCH,
}
for model, code in models.items():
model_subdir = model_dir / model
model_subdir.mkdir(exist_ok=True)
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8") as f:
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8", newline="\n") as f:
f.write(code)
@ -176,11 +326,47 @@ class CopyCheckTester(unittest.TestCase):
diffs = is_copy_consistent(file_to_check)
self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
diffs = is_copy_consistent(file_to_check, overwrite=True)
_ = is_copy_consistent(file_to_check, overwrite=True)
with open(file_to_check, "r", encoding="utf-8") as f:
self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)
def test_is_copy_consistent_with_ignored_match(self):
    """The copy carrying `# Ignore copy` blocks must be reported as consistent with its source."""
    relative_parts = ("src", "transformers", "models", "dummy_roberta_match", "modeling_dummy_roberta_match.py")

    with tempfile.TemporaryDirectory() as tmp_folder:
        # Materialize the mock repository, then run the check against it.
        create_tmp_repo(tmp_folder)
        target_file = os.path.join(tmp_folder, *relative_parts)
        with patch_transformer_repo_path(tmp_folder):
            # No diff is expected: every difference lives in a block marked `# Ignore copy`.
            self.assertEqual(is_copy_consistent(target_file), [])
def test_is_copy_consistent_with_ignored_no_match(self):
    """A copy with differences outside `# Ignore copy` blocks is reported, then fixed when overwriting."""
    relative_parts = (
        "src",
        "transformers",
        "models",
        "dummy_roberta_no_match",
        "modeling_dummy_roberta_no_match.py",
    )
    expected_diffs = [["models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel", 6]]

    with tempfile.TemporaryDirectory() as tmp_folder:
        # Base check with an inconsistency
        create_tmp_repo(tmp_folder)
        target_file = os.path.join(tmp_folder, *relative_parts)
        with patch_transformer_repo_path(tmp_folder):
            # The reported line is 6: `attr_2 = 3` in `MOCK_DUMMY_ROBERTA_CODE_NO_MATCH`
            # (the mock string starts with a leading newline).
            self.assertEqual(is_copy_consistent(target_file), expected_diffs)

            # Second pass rewrites the file; the returned diffs are not needed here.
            is_copy_consistent(target_file, overwrite=True)

            with open(target_file, "r", encoding="utf-8") as handle:
                rewritten = handle.read()
            self.assertEqual(rewritten, EXPECTED_REPLACED_CODE)
def test_convert_to_localized_md(self):
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]

View File

@ -41,7 +41,8 @@ import glob
import os
import re
import subprocess
from typing import List, Optional, Tuple
from collections import OrderedDict
from typing import List, Optional, Tuple, Union
from transformers.utils import direct_transformers_import
@ -125,13 +126,213 @@ LOCALIZED_READMES = {
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
def _is_definition_header_ending_line(line: str) -> bool:
# Helper function. Returns `True` if `line` is the end parenthesis of a class/function definition
return re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
def _should_continue(line: str, indent: str) -> bool:
# Helper function. Returns `True` if `line` is empty, starts with the `indent` or is the end parenthesis of a
# function definition
return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
# class/function definition
return line.startswith(indent) or len(line.strip()) == 0 or _is_definition_header_ending_line(line)
def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
def _sanity_check_splits(splits_1, splits_2, is_class):
"""Check the two (inner) block structures of the corresponding code block given by `split_code_into_blocks` match.
For the case of `class`, they must be of one of the following 3 cases:
- a single block without name:
class foo:
a = 1
- a consecutive sequence of (1 or more) blocks with name
class foo:
def f(x):
return x
- a block without name, followed by a consecutive sequence of (1 or more) blocks with name
class foo:
a = 1
def f(x):
return x
def g(x):
return None
The 2 code snippets that give `splits_1` and `splits_2` have to be in the same case to pass this check, but the
number of blocks with name in the consecutive sequence is not taken into account.
For the case of `function or method`, we don't require it to be in one of the above 3 cases. However, the structure
of`splits_1` and `splits_2` have to match exactly. In particular, the number of blocks with name in a consecutive
sequence is taken into account.
"""
block_names_1 = []
block_names_2 = []
for block in splits_1[1:]:
if block[0].startswith("_block_without_name_"):
block_names_1.append("block_without_name")
elif not block[0].startswith("_empty_block_") and (
not is_class or len(block_names_1) == 0 or block_names_1[-1].startswith("block_without_name")
):
block_names_1.append("block_with_name")
for block in splits_2[1:]:
if block[0].startswith("_block_without_name_"):
block_names_2.append("block_without_name")
elif not block[0].startswith("_empty_block_") and (
not is_class or len(block_names_2) == 0 or block_names_2[-1].startswith("block_without_name")
):
block_names_2.append("block_with_name")
if is_class:
if block_names_1 not in [
["block_without_name"],
["block_with_name"],
["block_without_name", "block_with_name"],
]:
raise ValueError(
"For a class, it must have a specific structure. See the docstring of `_sanity_check_splits` in the file `utils/check_copies.py`"
)
if block_names_1 != block_names_2:
raise ValueError("The structures in the 2 code blocks differ.")
def find_block_end(lines: List[str], start_index: int, indent: int) -> int:
    """
    Locate where the class/func block whose header sits at `start_index` stops in `lines`.

    Args:
        lines (`List[str]`):
            The source code, as a list of lines.
        start_index (`int`):
            The index of the block's definition header.
        indent (`int`):
            The number of spaces the block *body* is indented by.

    Returns:
        `int`: The index just past the block's last line (i.e. an exclusive end index).
    """
    body_prefix = " " * indent
    total = len(lines)

    # Walk through the body, starting right after the header line.
    cursor = start_index + 1
    while cursor < total and _should_continue(lines[cursor], body_prefix):
        cursor += 1

    # Back up over trailing lines of at most one character, so they are not counted as part of the block.
    while len(lines[cursor - 1]) <= 1:
        cursor -= 1

    return cursor
def split_code_into_blocks(
lines: List[str], start_index: int, end_index: int, indent: int, backtrace: bool = False
) -> List[Tuple[str, int, int]]:
"""
Split the class/func block starting at `start_index` in a source code (defined by `lines`) into *inner blocks*.
The block's header is included as the first element. The contiguous regions (without empty lines) that are not
inside any inner block are included as blocks. The contiguous regions of empty lines that are not inside any inner
block are also included as (dummy) blocks.
Args:
lines (`List[str]`):
The source code, represented by a list of lines.
start_index (`int`):
The starting index of the target class/func block.
end_index (`int`):
The ending index of the target class/func block.
indent (`int`):
The indent of the class/func body.
backtrace (`bool`, *optional*, defaults to `False`):
Whether or not to include the lines before the inner class/func block's header (e.g. comments, decorators,
etc.) until an empty line is encountered.
Returns:
`List[Tuple[str, int, int]]`: A list of elements with the form `(block_name, start_index, end_index)`.
"""
splits = []
# `indent - 4` is the indent level of the target class/func header
target_block_name = re.search(rf"^{' ' * (indent - 4)}((class|def)\s+\S+)(\(|\:)", lines[start_index]).groups()[0]
# from now on, the `block` means inner blocks unless explicitly specified
indent_str = " " * indent
block_without_name_idx = 0
empty_block_idx = 0
# Find the lines for the definition header
index = start_index
if "(" in lines[start_index] and "):" not in lines[start_index] in lines[start_index]:
while index < end_index:
if _is_definition_header_ending_line(lines[index]):
break
index += 1
# the first line outside the definition header
index += 1
splits.append((target_block_name, start_index, index))
block_start_index, prev_block_end_index = index, index
while index < end_index:
# if found, it will be an inner block
block_found = re.search(rf"^{indent_str}((class|def)\s+\S+)(\(|\:)", lines[index])
if block_found:
name = block_found.groups()[0]
block_end_index = find_block_end(lines, index, indent + 4)
# backtrace to include the lines before the found block's definition header (e.g. comments, decorators,
# etc.) until an empty line is encountered.
block_start_index = index
if index > prev_block_end_index and backtrace:
idx = index - 1
for idx in range(index - 1, prev_block_end_index - 2, -1):
if not (len(lines[idx].strip()) > 0 and lines[idx].startswith(indent_str)):
break
idx += 1
if idx < index:
block_start_index = idx
# between the current found block and the previous found block
if block_start_index > prev_block_end_index:
# give it a dummy name
if len("".join(lines[prev_block_end_index:block_start_index]).strip()) == 0:
prev_block_name = f"_empty_block_{empty_block_idx}"
empty_block_idx += 1
else:
prev_block_name = f"_block_without_name_{block_without_name_idx}"
block_without_name_idx += 1
# Add it as a block
splits.append((prev_block_name, prev_block_end_index, block_start_index))
# Add the current found block
splits.append((name, block_start_index, block_end_index))
prev_block_end_index = block_end_index
index = block_end_index - 1
index += 1
if index > prev_block_end_index:
if len("".join(lines[prev_block_end_index:index]).strip()) == 0:
prev_block_name = f"_empty_block_{empty_block_idx}"
else:
prev_block_name = f"_block_without_name_{block_without_name_idx}"
splits.append((prev_block_name, prev_block_end_index, index))
return splits
def find_code_in_transformers(
object_name: str, base_path: str = None, return_indices: bool = False
) -> Union[str, Tuple[List[str], int, int]]:
"""
Find and return the source code of an object.
@ -140,9 +341,15 @@ def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
The name of the object we want the source code of.
base_path (`str`, *optional*):
The path to the base folder where files are checked. If not set, it will be set to `TRANSFORMERS_PATH`.
return_indices(`bool`, *optional*, defaults to `False`):
If `False`, will only return the code (as a string), otherwise it will also return the whole lines of the
file where the object specified by `object_name` is defined, together the start/end indices of the block in
the file that defines the object.
Returns:
`str`: The source code of the object.
`Union[str, Tuple[List[str], int, int]]`: If `return_indices=False`, only the source code of the object will be
returned. Otherwise, it also returns the whole lines of the file where the object specified by `object_name` is
defined, together the start/end indices of the block in the file that defines the object.
"""
parts = object_name.split(".")
i = 0
@ -181,22 +388,91 @@ def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
):
line_index += 1
# find the target specified in the current level in `parts` -> increase `indent` so we can search the next
indent += " "
# the index of the first line in the (currently found) block *body*
line_index += 1
if line_index >= len(lines):
raise ValueError(f" {object_name} does not match any function or class in {module}.")
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
start_index = line_index - 1
while line_index < len(lines) and _should_continue(lines[line_index], indent):
line_index += 1
# Clean up empty lines at the end (if any).
while len(lines[line_index - 1]) <= 1:
line_index -= 1
# `indent` is already one level deeper than the (found) class/func block's definition header
code_lines = lines[start_index:line_index]
return "".join(code_lines)
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
# `start_index` is the index of the class/func block's definition header
start_index = line_index - 1
end_index = find_block_end(lines, start_index, len(indent))
code = "".join(lines[start_index:end_index])
return (code, (lines, start_index, end_index)) if return_indices else code
def replace_code(code: str, replace_pattern: str) -> str:
    """Apply to `code` the substitutions described by a pattern of the form `with X1->X2,Y1->Y2,Z1->Z2`.

    Args:
        code (`str`): The code to be modified.
        replace_pattern (`str`): The pattern used to modify `code`. An empty pattern leaves `code` untouched.

    Returns:
        `str`: The modified code.
    """
    if len(replace_pattern) == 0:
        return code

    # Drop the `with` keyword, then handle each comma-separated rule in turn.
    for rule in replace_pattern.replace("with", "").split(","):
        parsed = _re_replace_pattern.search(rule)
        if parsed is None:
            # Not a recognizable replacement rule: skip it.
            continue
        source, target, option = parsed.groups()
        code = re.sub(source, target, code)
        if option.strip() == "all-casing":
            # Also replace the lower-cased and upper-cased variants.
            code = re.sub(source.lower(), target.lower(), code)
            code = re.sub(source.upper(), target.upper(), code)

    return code
def find_code_and_splits(object_name: str, base_path: str, buffer: dict = None):
    """Retrieve the code of an object (specified by `object_name`) together with its split into blocks.

    Args:
        object_name (`str`):
            The name of the object, e.g. `transformers.models.bert.modeling_bert.BertAttention` or
            `tests.models.llama.test_modeling_llama.LlamaModelTest.test_config`.
        base_path (`str`):
            The path to the base directory within which the search will be performed. It could be either
            `TRANSFORMERS_PATH` or `MODEL_TEST_PATH`.
        buffer (`dict`, *optional*):
            A cache keyed by `(object_name, base_path)`. Previous results are read from it and new results are stored
            in it, in order to speed up the process.

    Returns:
        lines (`List[str]`):
            The lines of the whole file where the object is defined.
        code (`str`):
            The object's code.
        code_splits (`List[Tuple[str, int, int]]`):
            `code` split into blocks. See `split_code_into_blocks`.
    """
    cache = {} if buffer is None else buffer
    cache_key = (object_name, base_path)

    if cache_key not in cache:
        code, (lines, target_start_index, target_end_index) = find_code_in_transformers(
            object_name, base_path=base_path, return_indices=True
        )
        # `get_indent(code)` is the indent of the class/func definition header, while `split_code_into_blocks`
        # expects the indent level of the block body, hence the `+ 4`.
        body_indent = len(get_indent(code)) + 4
        code_splits = split_code_into_blocks(
            lines, target_start_index, target_end_index, body_indent, backtrace=True
        )
        cache[cache_key] = lines, code, code_splits

    lines, code, code_splits = cache[cache_key]
    return lines, code, code_splits
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
@ -285,7 +561,7 @@ def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int
diff_index += 1
def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[Tuple[str, int]]]:
def is_copy_consistent(filename: str, overwrite: bool = False, buffer: dict = None) -> Optional[List[Tuple[str, int]]]:
"""
Check if the code commented as a copy in a file matches the original.
@ -294,11 +570,15 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
The name of the file to check.
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to overwrite the copies when they don't match.
buffer (`dict`, *optional*):
The buffer used to store the previous results in order to speed up the process.
Returns:
`Optional[List[Tuple[str, int]]]`: If `overwrite=False`, returns the list of differences as tuples `(str, int)`
with the name of the object having a diff and the line number where theere is the first diff.
"""
base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH
with open(filename, "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines()
diffs = []
@ -317,16 +597,31 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
# There is some copied code here, let's retrieve the original.
indent, object_name, replace_pattern = search.groups()
base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH
theoretical_code = find_code_in_transformers(object_name, base_path=base_path)
# Find the file lines, the object's code, and its blocks
target_lines, theoretical_code, theoretical_code_splits = find_code_and_splits(
object_name, base_path, buffer=buffer
)
# code replaced by the patterns
theoretical_code_blocks = OrderedDict()
for name, start, end in theoretical_code_splits:
name = replace_code(name, replace_pattern)
code = "".join(target_lines[start:end])
code = replace_code(code, replace_pattern)
theoretical_code_blocks[name] = code
theoretical_indent = get_indent(theoretical_code)
# `start_index` is the index of the first line (the definition header) after `# Copied from`.
# (`indent != theoretical_indent` doesn't seem to occur so far, not sure what this case is for.)
start_index = line_index + 1 if indent == theoretical_indent else line_index
# enter the block body
line_index = start_index + 1
subcode = "\n".join(theoretical_code.split("\n")[1:])
indent = get_indent(subcode)
# Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
# We can't call `find_block_end` directly as there is sth. special `# End copy"` here.
should_continue = True
while line_index < len(lines) and should_continue:
line_index += 1
@ -336,33 +631,118 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
# There is a special pattern `# End copy` to stop early. It's not documented cause it shouldn't really be
# used.
should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
# `line_index` is outside the block
# Clean up empty lines at the end (if any).
while len(lines[line_index - 1]) <= 1:
line_index -= 1
observed_code_lines = lines[start_index:line_index]
observed_code = "".join(observed_code_lines)
# Split the observed code into blocks
observed_code_splits = split_code_into_blocks(lines, start_index, line_index, len(indent), backtrace=True)
# Before comparing, use the `replace_pattern` on the original code.
if len(replace_pattern) > 0:
patterns = replace_pattern.replace("with", "").split(",")
patterns = [_re_replace_pattern.search(p) for p in patterns]
for pattern in patterns:
if pattern is None:
continue
obj1, obj2, option = pattern.groups()
theoretical_code = re.sub(obj1, obj2, theoretical_code)
if option.strip() == "all-casing":
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
is_class = lines[start_index].startswith(f"{' ' * (len(indent) - 4)}class ")
# sanity check
_sanity_check_splits(theoretical_code_splits, observed_code_splits, is_class=is_class)
# observed code in a structured way (a dict mapping block names to blocks' code)
observed_code_blocks = OrderedDict()
for name, start, end in observed_code_splits:
code = "".join(lines[start:end])
observed_code_blocks[name] = code
# Below, we change some names in `theoretical_code_blocks` and `observed_code_blocks`. These mappings map the
# original names to the modified names: this is used to restore the original order of the code blocks.
name_mappings_1 = {k: k for k in theoretical_code_blocks.keys()}
name_mappings_2 = {k: k for k in observed_code_blocks.keys()}
# Update code blocks' name and content:
# If `"# Ignore copy"` is found in a block of the observed code:
# 1. if it's a block only in the observed code --> add it to the theoretical code.
# 2. if it's also in the theoretical code () --> put its content (body) to the corresponding block under the
# same name in the theoretical code.
# In both cases, we change the name to have a prefix `_ignored_` so we know if we can discard them during the
# comparison.
ignored_existing_block_index = 0
ignored_new_block_index = 0
for name in list(observed_code_blocks.keys()):
code = observed_code_blocks[name]
if "# Ignore copy" in code:
if name in theoretical_code_blocks:
# in the target --> just copy the content
del theoretical_code_blocks[name]
theoretical_code_blocks[f"_ignored_existing_block_{ignored_existing_block_index}"] = code
name_mappings_1[name] = f"_ignored_existing_block_{ignored_existing_block_index}"
del observed_code_blocks[name]
observed_code_blocks[f"_ignored_existing_block_{ignored_existing_block_index}"] = code
name_mappings_2[name] = f"_ignored_existing_block_{ignored_existing_block_index}"
ignored_existing_block_index += 1
else:
# not in the target --> add it
theoretical_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code
name_mappings_1[
f"_ignored_new_block_{ignored_new_block_index}"
] = f"_ignored_new_block_{ignored_new_block_index}"
del observed_code_blocks[name]
observed_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code
name_mappings_2[name] = f"_ignored_new_block_{ignored_new_block_index}"
ignored_new_block_index += 1
# Respect the original block order:
# 1. in `theoretical_code_blocks`: the new blocks will follow the existing ones
# 2. in `observed_code_blocks`: the original order are kept with names modified potentially. This is necessary
# to compute the correct `diff_index` if `overwrite=True` and there is a diff.
theoretical_code_blocks = {
name_mappings_1[orig_name]: theoretical_code_blocks[name_mappings_1[orig_name]]
for orig_name in name_mappings_1
}
observed_code_blocks = {
name_mappings_2[orig_name]: observed_code_blocks[name_mappings_2[orig_name]]
for orig_name in name_mappings_2
}
# Ignore the blocks specified to be ignored. This is the version used to check if there is a mismatch
theoretical_code_blocks_clean = {
k: v
for k, v in theoretical_code_blocks.items()
if not (k.startswith(("_ignored_existing_block_", "_ignored_new_block_")))
}
theoretical_code = "".join(list(theoretical_code_blocks_clean.values()))
# stylify `theoretical_code` before compare (this is needed only when `replace_pattern` is not empty)
if replace_pattern:
theoretical_code = stylify(theoretical_code)
# Remove `\n\n` in `theoretical_code` before compare (so no empty line)
while "\n\n" in theoretical_code:
theoretical_code = theoretical_code.replace("\n\n", "\n")
# Compute `observed_code` where we don't include any empty line + keep track the line index between the
# original/processed `observed_code` so we can have the correct `diff_index`.
idx_to_orig_idx_mapping_for_observed_code_lines = {}
idx = -1
orig_idx = -1
observed_code = ""
for name, code in observed_code_blocks.items():
if code.endswith("\n"):
code = code[:-1]
for code_line in code.split("\n"):
orig_idx += 1
if code_line.strip() and not name.startswith(("_ignored_existing_block_", "_ignored_new_block_")):
idx += 1
observed_code += code_line + "\n"
idx_to_orig_idx_mapping_for_observed_code_lines[idx] = orig_idx
# Test for a diff and act accordingly.
diff_index = check_codes_match(observed_code, theoretical_code)
if diff_index is not None:
# switch to the index in the original `observed_code` (i.e. before removing empty lines)
diff_index = idx_to_orig_idx_mapping_for_observed_code_lines[diff_index]
diffs.append([object_name, diff_index + start_index + 1])
if overwrite:
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
# `theoretical_code_to_write` is a single string but may have several lines.
theoretical_code_to_write = stylify("".join(list(theoretical_code_blocks.values())))
lines = lines[:start_index] + [theoretical_code_to_write] + lines[line_index:]
# Here we treat it as a single entry in `lines`.
line_index = start_index + 1
if overwrite and len(diffs) > 0:
@ -373,7 +753,7 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
return diffs
def check_copies(overwrite: bool = False):
def check_copies(overwrite: bool = False, file: str = None):
"""
Check every file is copy-consistent with the original. Also check the model list in the main README and other
READMEs are consistent.
@ -381,14 +761,21 @@ def check_copies(overwrite: bool = False):
Args:
overwrite (`bool`, *optional*, defaults to `False`):
Whether or not to overwrite the copies when they don't match.
file (`bool`, *optional*):
The path to a specific file to check and/or fix.
"""
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True)
all_files = list(all_files) + list(all_test_files)
buffer = {}
if file is None:
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True)
all_files = list(all_files) + list(all_test_files)
else:
all_files = [file]
diffs = []
for filename in all_files:
new_diffs = is_copy_consistent(filename, overwrite)
new_diffs = is_copy_consistent(filename, overwrite, buffer)
diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
if not overwrite and len(diffs) > 0:
diff = "\n".join(diffs)
@ -733,9 +1120,10 @@ def check_readme(overwrite: bool = False):
if __name__ == "__main__":
    # Command line entry point: run the README check, the copy check and the full-copy check in turn.
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", type=str, default=None, help="A specific file to check and/or fix")
    parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
    args = parser.parse_args()

    check_readme(args.fix_and_overwrite)
    # Run the copy check once, forwarding `--file` so the check can be restricted to a single file.
    check_copies(args.fix_and_overwrite, args.file)
    check_full_copies(args.fix_and_overwrite)