diff --git a/tests/models/longformer/test_tokenization_longformer.py b/tests/models/longformer/test_tokenization_longformer.py index 61d8653b60..32dc0f952f 100644 --- a/tests/models/longformer/test_tokenization_longformer.py +++ b/tests/models/longformer/test_tokenization_longformer.py @@ -28,7 +28,9 @@ from ...test_tokenization_common import TokenizerTesterMixin @require_tokenizers +# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest with roberta-base->allenai/longformer-base-4096,Roberta->Longformer,roberta->longformer, class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): + # Ignore copy tokenizer_class = LongformerTokenizer test_slow_tokenizer = True rust_tokenizer_class = LongformerTokenizerFast @@ -71,23 +73,19 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): with open(self.merges_file, "w", encoding="utf-8") as fp: fp.write("\n".join(merges)) - # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_tokenizer def get_tokenizer(self, **kwargs): kwargs.update(self.special_tokens_map) return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) - # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_rust_tokenizer def get_rust_tokenizer(self, **kwargs): kwargs.update(self.special_tokens_map) return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs) - # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_input_output_texts def get_input_output_texts(self, tokenizer): input_text = "lower newer" output_text = "lower newer" return input_text, output_text - # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_full_tokenizer def test_full_tokenizer(self): tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map) text = "lower newer" @@ -99,7 +97,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19] self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) - # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.roberta_dict_integration_testing with roberta->longformer def longformer_dict_integration_testing(self): tokenizer = self.get_tokenizer() @@ -110,7 +107,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ) @slow - # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_sequence_builders with roberta-base->allenai/longformer-base-4096 def test_sequence_builders(self): tokenizer = self.tokenizer_class.from_pretrained("allenai/longformer-base-4096") @@ -130,7 +126,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): assert encoded_sentence == encoded_text_from_decode assert encoded_pair == encoded_pair_from_decode - # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_space_encoding def test_space_encoding(self): tokenizer = self.get_tokenizer() @@ -171,11 +166,9 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0] self.assertNotEqual(first_char, space_encoding) - # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_pretokenized_inputs def test_pretokenized_inputs(self): pass - # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_embeded_special_tokens def test_embeded_special_tokens(self): for tokenizer, pretrained_name, kwargs in self.tokenizers_list: with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"): @@ -208,7 +201,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokens_r_str, ["", "A", ",", "", "ĠAllen", "N", "LP", "Ġsentence", ".", ""] ) - # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_change_add_prefix_space_and_trim_offsets_args def test_change_add_prefix_space_and_trim_offsets_args(self): for trim_offsets, add_prefix_space in itertools.product([True, False], repeat=2): tokenizer_r = self.rust_tokenizer_class.from_pretrained( @@ -223,7 +215,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase): self.assertEqual(post_processor_state["add_prefix_space"], add_prefix_space) self.assertEqual(post_processor_state["trim_offsets"], trim_offsets) - # Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments def test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments(self): # Test which aims to verify that the offsets are well adapted to the argument `add_prefix_space` and # `trim_offsets` diff --git a/tests/repo_utils/test_check_copies.py b/tests/repo_utils/test_check_copies.py index e3e8e47a87..6afed02895 100644 --- a/tests/repo_utils/test_check_copies.py +++ b/tests/repo_utils/test_check_copies.py @@ -95,13 +95,156 @@ class BertCopyModel(BertCopyPreTrainedModel): """ +MOCK_DUMMY_BERT_CODE_MATCH = """ +class BertDummyModel: + attr_1 = 1 + attr_2 = 2 + + def __init__(self, a=1, b=2): + self.a = a + self.b = b + + # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward + def forward(self, c): + return 1 + + def existing_common(self, c): + return 4 + + def existing_diff_to_be_ignored(self, c): + return 9 +""" + + +MOCK_DUMMY_ROBERTA_CODE_MATCH = """ +# Copied from transformers.models.dummy_bert_match.modeling_dummy_bert_match.BertDummyModel with BertDummy->RobertaBertDummy +class RobertaBertDummyModel: + + attr_1 = 1 + attr_2 = 2 + + def __init__(self, a=1, b=2): + self.a = a + self.b = b + + # Ignore copy + def only_in_roberta_to_be_ignored(self, c): + return 3 + + # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward + def forward(self, c): + return 1 + + def existing_common(self, c): + return 4 + + # Ignore copy + def existing_diff_to_be_ignored(self, c): + return 6 +""" + + +MOCK_DUMMY_BERT_CODE_NO_MATCH = """ +class BertDummyModel: + attr_1 = 1 + attr_2 = 2 + + def __init__(self, a=1, b=2): + self.a = a + self.b = b + + # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward + def forward(self, c): + return 1 + + def only_in_bert(self, c): + return 7 + + def existing_common(self, c): + return 4 + + def existing_diff_not_ignored(self, c): + return 8 + + def existing_diff_to_be_ignored(self, c): + return 9 +""" + + +MOCK_DUMMY_ROBERTA_CODE_NO_MATCH = """ +# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy +class RobertaBertDummyModel: + + attr_1 = 1 + attr_2 = 3 + + def __init__(self, a=1, b=2): + self.a = a + self.b = b + + # Ignore copy + def only_in_roberta_to_be_ignored(self, c): + return 3 + + # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward + def forward(self, c): + return 1 + + def only_in_roberta_not_ignored(self, c): + return 2 + + def existing_common(self, c): + return 4 + + def existing_diff_not_ignored(self, c): + return 5 + + # Ignore copy + def existing_diff_to_be_ignored(self, c): + return 6 +""" + + +EXPECTED_REPLACED_CODE = """ +# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy +class RobertaBertDummyModel: + attr_1 = 1 + attr_2 = 2 + + def __init__(self, a=1, b=2): + self.a = a + self.b = b + + # Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward + def forward(self, c): + return 1 + + def only_in_bert(self, c): + return 7 + + def existing_common(self, c): + return 4 + + def existing_diff_not_ignored(self, c): + return 8 + + # Ignore copy + def existing_diff_to_be_ignored(self, c): + return 6 + + # Ignore copy + def only_in_roberta_to_be_ignored(self, c): + return 3 +""" + + def replace_in_file(filename, old, new): with open(filename, "r", encoding="utf-8") as f: content = f.read() content = content.replace(old, new) - with open(filename, "w", encoding="utf-8") as f: + with open(filename, "w", encoding="utf-8", newline="\n") as f: f.write(content) @@ -117,11 +260,18 @@ def create_tmp_repo(tmp_dir): model_dir = tmp_dir / "src" / "transformers" / "models" model_dir.mkdir(parents=True, exist_ok=True) - models = {"bert": MOCK_BERT_CODE, "bertcopy": MOCK_BERT_COPY_CODE} + models = { + "bert": MOCK_BERT_CODE, + "bertcopy": MOCK_BERT_COPY_CODE, + "dummy_bert_match": MOCK_DUMMY_BERT_CODE_MATCH, + "dummy_roberta_match": MOCK_DUMMY_ROBERTA_CODE_MATCH, + "dummy_bert_no_match": MOCK_DUMMY_BERT_CODE_NO_MATCH, + "dummy_roberta_no_match": MOCK_DUMMY_ROBERTA_CODE_NO_MATCH, + } for model, code in models.items(): model_subdir = model_dir / model model_subdir.mkdir(exist_ok=True) - with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8") as f: + with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8", newline="\n") as f: f.write(code) @@ -176,11 +326,47 @@ class CopyCheckTester(unittest.TestCase): diffs = is_copy_consistent(file_to_check) self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]]) - diffs = is_copy_consistent(file_to_check, overwrite=True) + _ = is_copy_consistent(file_to_check, overwrite=True) with open(file_to_check, "r", encoding="utf-8") as f: self.assertEqual(f.read(), MOCK_BERT_COPY_CODE) + def test_is_copy_consistent_with_ignored_match(self): + path_to_check = ["src", "transformers", "models", "dummy_roberta_match", "modeling_dummy_roberta_match.py"] + with tempfile.TemporaryDirectory() as tmp_folder: + # Base check + create_tmp_repo(tmp_folder) + with patch_transformer_repo_path(tmp_folder): + file_to_check = os.path.join(tmp_folder, *path_to_check) + diffs = is_copy_consistent(file_to_check) + self.assertEqual(diffs, []) + + def test_is_copy_consistent_with_ignored_no_match(self): + path_to_check = [ + "src", + "transformers", + "models", + "dummy_roberta_no_match", + "modeling_dummy_roberta_no_match.py", + ] + with tempfile.TemporaryDirectory() as tmp_folder: + # Base check with an inconsistency + create_tmp_repo(tmp_folder) + with patch_transformer_repo_path(tmp_folder): + file_to_check = os.path.join(tmp_folder, *path_to_check) + + diffs = is_copy_consistent(file_to_check) + # line 6: `attr_2 = 3` in `MOCK_DUMMY_ROBERTA_CODE_NO_MATCH`. + # (which has a leading `\n`.) + self.assertEqual( + diffs, [["models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel", 6]] + ) + + _ = is_copy_consistent(file_to_check, overwrite=True) + + with open(file_to_check, "r", encoding="utf-8") as f: + self.assertEqual(f.read(), EXPECTED_REPLACED_CODE) + def test_convert_to_localized_md(self): localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"] diff --git a/utils/check_copies.py b/utils/check_copies.py index 4fd3a7c23e..3d352cc831 100644 --- a/utils/check_copies.py +++ b/utils/check_copies.py @@ -41,7 +41,8 @@ import glob import os import re import subprocess -from typing import List, Optional, Tuple +from collections import OrderedDict +from typing import List, Optional, Tuple, Union from transformers.utils import direct_transformers_import @@ -125,13 +126,213 @@ LOCALIZED_READMES = { transformers_module = direct_transformers_import(TRANSFORMERS_PATH) +def _is_definition_header_ending_line(line: str) -> bool: + # Helper function. Returns `True` if `line` is the end parenthesis of a class/function definition + return re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None + + def _should_continue(line: str, indent: str) -> bool: # Helper function. Returns `True` if `line` is empty, starts with the `indent` or is the end parenthesis of a - # function definition - return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None + # class/function definition + return line.startswith(indent) or len(line.strip()) == 0 or _is_definition_header_ending_line(line) -def find_code_in_transformers(object_name: str, base_path: str = None) -> str: +def _sanity_check_splits(splits_1, splits_2, is_class): + """Check the two (inner) block structures of the corresponding code block given by `split_code_into_blocks` match. + + For the case of `class`, they must be of one of the following 3 cases: + + - a single block without name: + + class foo: + a = 1 + + - a consecutive sequence of (1 or more) blocks with name + + class foo: + + def f(x): + return x + + - a block without name, followed by a consecutive sequence of (1 or more) blocks with name + + class foo: + a = 1 + + def f(x): + return x + + def g(x): + return None + + The 2 code snippets that give `splits_1` and `splits_2` have to be in the same case to pass this check, but the + number of blocks with name in the consecutive sequence is not taken into account. + + For the case of `function or method`, we don't require it to be in one of the above 3 cases. However, the structure + of`splits_1` and `splits_2` have to match exactly. In particular, the number of blocks with name in a consecutive + sequence is taken into account. + """ + block_names_1 = [] + block_names_2 = [] + + for block in splits_1[1:]: + if block[0].startswith("_block_without_name_"): + block_names_1.append("block_without_name") + elif not block[0].startswith("_empty_block_") and ( + not is_class or len(block_names_1) == 0 or block_names_1[-1].startswith("block_without_name") + ): + block_names_1.append("block_with_name") + + for block in splits_2[1:]: + if block[0].startswith("_block_without_name_"): + block_names_2.append("block_without_name") + elif not block[0].startswith("_empty_block_") and ( + not is_class or len(block_names_2) == 0 or block_names_2[-1].startswith("block_without_name") + ): + block_names_2.append("block_with_name") + + if is_class: + if block_names_1 not in [ + ["block_without_name"], + ["block_with_name"], + ["block_without_name", "block_with_name"], + ]: + raise ValueError( + "For a class, it must have a specific structure. See the docstring of `_sanity_check_splits` in the file `utils/check_copies.py`" + ) + + if block_names_1 != block_names_2: + raise ValueError("The structures in the 2 code blocks differ.") + + +def find_block_end(lines: List[str], start_index: int, indent: int) -> int: + """ + Find the end of the class/func block starting at `start_index` in a source code (defined by `lines`). + + Args: + lines (`List[str]`): + The source code, represented by a list of lines. + start_index (`int`): + The starting index of the target class/func block. + indent (`int`): + The indent of the class/func body. + + Returns: + `int`: The index of the block's ending line plus by 1 (i.e. exclusive). + """ + indent = " " * indent + # enter the block body + line_index = start_index + 1 + + while line_index < len(lines) and _should_continue(lines[line_index], indent): + line_index += 1 + # Clean up empty lines at the end (if any). + while len(lines[line_index - 1]) <= 1: + line_index -= 1 + + return line_index + + +def split_code_into_blocks( + lines: List[str], start_index: int, end_index: int, indent: int, backtrace: bool = False +) -> List[Tuple[str, int, int]]: + """ + Split the class/func block starting at `start_index` in a source code (defined by `lines`) into *inner blocks*. + + The block's header is included as the first element. The contiguous regions (without empty lines) that are not + inside any inner block are included as blocks. The contiguous regions of empty lines that are not inside any inner + block are also included as (dummy) blocks. + + Args: + lines (`List[str]`): + The source code, represented by a list of lines. + start_index (`int`): + The starting index of the target class/func block. + end_index (`int`): + The ending index of the target class/func block. + indent (`int`): + The indent of the class/func body. + backtrace (`bool`, *optional*, defaults to `False`): + Whether or not to include the lines before the inner class/func block's header (e.g. comments, decorators, + etc.) until an empty line is encountered. + + Returns: + `List[Tuple[str, int, int]]`: A list of elements with the form `(block_name, start_index, end_index)`. + """ + splits = [] + # `indent - 4` is the indent level of the target class/func header + target_block_name = re.search(rf"^{' ' * (indent - 4)}((class|def)\s+\S+)(\(|\:)", lines[start_index]).groups()[0] + + # from now on, the `block` means inner blocks unless explicitly specified + indent_str = " " * indent + block_without_name_idx = 0 + empty_block_idx = 0 + + # Find the lines for the definition header + index = start_index + if "(" in lines[start_index] and "):" not in lines[start_index] in lines[start_index]: + while index < end_index: + if _is_definition_header_ending_line(lines[index]): + break + index += 1 + + # the first line outside the definition header + index += 1 + splits.append((target_block_name, start_index, index)) + + block_start_index, prev_block_end_index = index, index + while index < end_index: + # if found, it will be an inner block + block_found = re.search(rf"^{indent_str}((class|def)\s+\S+)(\(|\:)", lines[index]) + if block_found: + name = block_found.groups()[0] + + block_end_index = find_block_end(lines, index, indent + 4) + + # backtrace to include the lines before the found block's definition header (e.g. comments, decorators, + # etc.) until an empty line is encountered. + block_start_index = index + if index > prev_block_end_index and backtrace: + idx = index - 1 + for idx in range(index - 1, prev_block_end_index - 2, -1): + if not (len(lines[idx].strip()) > 0 and lines[idx].startswith(indent_str)): + break + idx += 1 + if idx < index: + block_start_index = idx + + # between the current found block and the previous found block + if block_start_index > prev_block_end_index: + # give it a dummy name + if len("".join(lines[prev_block_end_index:block_start_index]).strip()) == 0: + prev_block_name = f"_empty_block_{empty_block_idx}" + empty_block_idx += 1 + else: + prev_block_name = f"_block_without_name_{block_without_name_idx}" + block_without_name_idx += 1 + # Add it as a block + splits.append((prev_block_name, prev_block_end_index, block_start_index)) + + # Add the current found block + splits.append((name, block_start_index, block_end_index)) + prev_block_end_index = block_end_index + index = block_end_index - 1 + + index += 1 + + if index > prev_block_end_index: + if len("".join(lines[prev_block_end_index:index]).strip()) == 0: + prev_block_name = f"_empty_block_{empty_block_idx}" + else: + prev_block_name = f"_block_without_name_{block_without_name_idx}" + splits.append((prev_block_name, prev_block_end_index, index)) + + return splits + + +def find_code_in_transformers( + object_name: str, base_path: str = None, return_indices: bool = False +) -> Union[str, Tuple[List[str], int, int]]: """ Find and return the source code of an object. @@ -140,9 +341,15 @@ def find_code_in_transformers(object_name: str, base_path: str = None) -> str: The name of the object we want the source code of. base_path (`str`, *optional*): The path to the base folder where files are checked. If not set, it will be set to `TRANSFORMERS_PATH`. + return_indices(`bool`, *optional*, defaults to `False`): + If `False`, will only return the code (as a string), otherwise it will also return the whole lines of the + file where the object specified by `object_name` is defined, together the start/end indices of the block in + the file that defines the object. Returns: - `str`: The source code of the object. + `Union[str, Tuple[List[str], int, int]]`: If `return_indices=False`, only the source code of the object will be + returned. Otherwise, it also returns the whole lines of the file where the object specified by `object_name` is + defined, together the start/end indices of the block in the file that defines the object. """ parts = object_name.split(".") i = 0 @@ -181,22 +388,91 @@ def find_code_in_transformers(object_name: str, base_path: str = None) -> str: line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None ): line_index += 1 + # find the target specified in the current level in `parts` -> increase `indent` so we can search the next indent += " " + # the index of the first line in the (currently found) block *body* line_index += 1 if line_index >= len(lines): raise ValueError(f" {object_name} does not match any function or class in {module}.") - # We found the beginning of the class / func, now let's find the end (when the indent diminishes). - start_index = line_index - 1 - while line_index < len(lines) and _should_continue(lines[line_index], indent): - line_index += 1 - # Clean up empty lines at the end (if any). - while len(lines[line_index - 1]) <= 1: - line_index -= 1 + # `indent` is already one level deeper than the (found) class/func block's definition header - code_lines = lines[start_index:line_index] - return "".join(code_lines) + # We found the beginning of the class / func, now let's find the end (when the indent diminishes). + # `start_index` is the index of the class/func block's definition header + start_index = line_index - 1 + end_index = find_block_end(lines, start_index, len(indent)) + + code = "".join(lines[start_index:end_index]) + return (code, (lines, start_index, end_index)) if return_indices else code + + +def replace_code(code: str, replace_pattern: str) -> str: + """Replace `code` by a pattern of the form `with X1->X2,Y1->Y2,Z1->Z2`. + + Args: + code (`str`): The code to be modified. + replace_pattern (`str`): The pattern used to modify `code`. + + Returns: + `str`: The modified code. + """ + if len(replace_pattern) > 0: + patterns = replace_pattern.replace("with", "").split(",") + patterns = [_re_replace_pattern.search(p) for p in patterns] + for pattern in patterns: + if pattern is None: + continue + obj1, obj2, option = pattern.groups() + code = re.sub(obj1, obj2, code) + if option.strip() == "all-casing": + code = re.sub(obj1.lower(), obj2.lower(), code) + code = re.sub(obj1.upper(), obj2.upper(), code) + + return code + + +def find_code_and_splits(object_name: str, base_path: str, buffer: dict = None): + """Find the code of an object (specified by `object_name`) and split it into blocks. + + Args: + object_name (`str`): + The name of the object, e.g. `transformers.models.bert.modeling_bert.BertAttention` or + `tests.models.llama.test_modeling_llama.LlamaModelTest.test_config`. + base_path (`str`): + The path to the base directory within which the search will be performed. It could be either + `TRANSFORMERS_PATH` or `MODEL_TEST_PATH`. + buffer (`dict`, *optional*): + The buffer used to store the previous results in order to speed up the process. + + Returns: + lines (`List[str]`): + The lines of the whole file where the object is defined. + code (`str`): + The object's code. + code_splits (`List[Tuple[str, int, int]]`): + `code` splitted into blocks. See `split_code_into_blocks`. + """ + if buffer is None: + buffer = {} + + if (object_name, base_path) in buffer: + lines, code, code_splits = buffer[(object_name, base_path)] + else: + code, (lines, target_start_index, target_end_index) = find_code_in_transformers( + object_name, base_path=base_path, return_indices=True + ) + indent = get_indent(code) + + # Split the code into blocks + # `indent` is the indent of the class/func definition header, but `code_splits` expects the indent level of the + # block body. + code_splits = split_code_into_blocks( + lines, target_start_index, target_end_index, len(indent) + 4, backtrace=True + ) + buffer[(object_name, base_path)] = lines, code, code_splits + + return lines, code, code_splits _re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)") @@ -285,7 +561,7 @@ def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int diff_index += 1 -def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[Tuple[str, int]]]: +def is_copy_consistent(filename: str, overwrite: bool = False, buffer: dict = None) -> Optional[List[Tuple[str, int]]]: """ Check if the code commented as a copy in a file matches the original. @@ -294,11 +570,15 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[ The name of the file to check. overwrite (`bool`, *optional*, defaults to `False`): Whether or not to overwrite the copies when they don't match. + buffer (`dict`, *optional*): + The buffer used to store the previous results in order to speed up the process. Returns: `Optional[List[Tuple[str, int]]]`: If `overwrite=False`, returns the list of differences as tuples `(str, int)` with the name of the object having a diff and the line number where theere is the first diff. """ + base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH + with open(filename, "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() diffs = [] @@ -317,16 +597,31 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[ # There is some copied code here, let's retrieve the original. indent, object_name, replace_pattern = search.groups() - base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH - theoretical_code = find_code_in_transformers(object_name, base_path=base_path) + # Find the file lines, the object's code, and its blocks + target_lines, theoretical_code, theoretical_code_splits = find_code_and_splits( + object_name, base_path, buffer=buffer + ) + + # code replaced by the patterns + theoretical_code_blocks = OrderedDict() + for name, start, end in theoretical_code_splits: + name = replace_code(name, replace_pattern) + code = "".join(target_lines[start:end]) + code = replace_code(code, replace_pattern) + theoretical_code_blocks[name] = code + theoretical_indent = get_indent(theoretical_code) + # `start_index` is the index of the first line (the definition header) after `# Copied from`. + # (`indent != theoretical_indent` doesn't seem to occur so far, not sure what this case is for.) start_index = line_index + 1 if indent == theoretical_indent else line_index + # enter the block body line_index = start_index + 1 subcode = "\n".join(theoretical_code.split("\n")[1:]) indent = get_indent(subcode) # Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment. + # We can't call `find_block_end` directly as there is sth. special `# End copy"` here. should_continue = True while line_index < len(lines) and should_continue: line_index += 1 @@ -336,33 +631,118 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[ # There is a special pattern `# End copy` to stop early. It's not documented cause it shouldn't really be # used. should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None + # `line_index` is outside the block # Clean up empty lines at the end (if any). while len(lines[line_index - 1]) <= 1: line_index -= 1 - observed_code_lines = lines[start_index:line_index] - observed_code = "".join(observed_code_lines) + # Split the observed code into blocks + observed_code_splits = split_code_into_blocks(lines, start_index, line_index, len(indent), backtrace=True) - # Before comparing, use the `replace_pattern` on the original code. - if len(replace_pattern) > 0: - patterns = replace_pattern.replace("with", "").split(",") - patterns = [_re_replace_pattern.search(p) for p in patterns] - for pattern in patterns: - if pattern is None: - continue - obj1, obj2, option = pattern.groups() - theoretical_code = re.sub(obj1, obj2, theoretical_code) - if option.strip() == "all-casing": - theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code) - theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code) + is_class = lines[start_index].startswith(f"{' ' * (len(indent) - 4)}class ") + # sanity check + _sanity_check_splits(theoretical_code_splits, observed_code_splits, is_class=is_class) + + # observed code in a structured way (a dict mapping block names to blocks' code) + observed_code_blocks = OrderedDict() + for name, start, end in observed_code_splits: + code = "".join(lines[start:end]) + observed_code_blocks[name] = code + + # Below, we change some names in `theoretical_code_blocks` and `observed_code_blocks`. These mappings map the + # original names to the modified names: this is used to restore the original order of the code blocks. + name_mappings_1 = {k: k for k in theoretical_code_blocks.keys()} + name_mappings_2 = {k: k for k in observed_code_blocks.keys()} + + # Update code blocks' name and content: + # If `"# Ignore copy"` is found in a block of the observed code: + # 1. if it's a block only in the observed code --> add it to the theoretical code. + # 2. if it's also in the theoretical code () --> put its content (body) to the corresponding block under the + # same name in the theoretical code. + # In both cases, we change the name to have a prefix `_ignored_` so we know if we can discard them during the + # comparison. + ignored_existing_block_index = 0 + ignored_new_block_index = 0 + for name in list(observed_code_blocks.keys()): + code = observed_code_blocks[name] + if "# Ignore copy" in code: + if name in theoretical_code_blocks: + # in the target --> just copy the content + del theoretical_code_blocks[name] + theoretical_code_blocks[f"_ignored_existing_block_{ignored_existing_block_index}"] = code + name_mappings_1[name] = f"_ignored_existing_block_{ignored_existing_block_index}" + + del observed_code_blocks[name] + observed_code_blocks[f"_ignored_existing_block_{ignored_existing_block_index}"] = code + name_mappings_2[name] = f"_ignored_existing_block_{ignored_existing_block_index}" + ignored_existing_block_index += 1 + else: + # not in the target --> add it + theoretical_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code + name_mappings_1[ + f"_ignored_new_block_{ignored_new_block_index}" + ] = f"_ignored_new_block_{ignored_new_block_index}" + + del observed_code_blocks[name] + observed_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code + name_mappings_2[name] = f"_ignored_new_block_{ignored_new_block_index}" + ignored_new_block_index += 1 + + # Respect the original block order: + # 1. in `theoretical_code_blocks`: the new blocks will follow the existing ones + # 2. in `observed_code_blocks`: the original order are kept with names modified potentially. This is necessary + # to compute the correct `diff_index` if `overwrite=True` and there is a diff. + theoretical_code_blocks = { + name_mappings_1[orig_name]: theoretical_code_blocks[name_mappings_1[orig_name]] + for orig_name in name_mappings_1 + } + observed_code_blocks = { + name_mappings_2[orig_name]: observed_code_blocks[name_mappings_2[orig_name]] + for orig_name in name_mappings_2 + } + + # Ignore the blocks specified to be ignored. This is the version used to check if there is a mismatch + theoretical_code_blocks_clean = { + k: v + for k, v in theoretical_code_blocks.items() + if not (k.startswith(("_ignored_existing_block_", "_ignored_new_block_"))) + } + theoretical_code = "".join(list(theoretical_code_blocks_clean.values())) + + # stylify `theoretical_code` before compare (this is needed only when `replace_pattern` is not empty) + if replace_pattern: theoretical_code = stylify(theoretical_code) + # Remove `\n\n` in `theoretical_code` before compare (so no empty line) + while "\n\n" in theoretical_code: + theoretical_code = theoretical_code.replace("\n\n", "\n") + + # Compute `observed_code` where we don't include any empty line + keep track the line index between the + # original/processed `observed_code` so we can have the correct `diff_index`. + idx_to_orig_idx_mapping_for_observed_code_lines = {} + idx = -1 + orig_idx = -1 + observed_code = "" + for name, code in observed_code_blocks.items(): + if code.endswith("\n"): + code = code[:-1] + for code_line in code.split("\n"): + orig_idx += 1 + if code_line.strip() and not name.startswith(("_ignored_existing_block_", "_ignored_new_block_")): + idx += 1 + observed_code += code_line + "\n" + idx_to_orig_idx_mapping_for_observed_code_lines[idx] = orig_idx # Test for a diff and act accordingly. diff_index = check_codes_match(observed_code, theoretical_code) if diff_index is not None: + # switch to the index in the original `observed_code` (i.e. before removing empty lines) + diff_index = idx_to_orig_idx_mapping_for_observed_code_lines[diff_index] diffs.append([object_name, diff_index + start_index + 1]) if overwrite: - lines = lines[:start_index] + [theoretical_code] + lines[line_index:] + # `theoretical_code_to_write` is a single string but may have several lines. + theoretical_code_to_write = stylify("".join(list(theoretical_code_blocks.values()))) + lines = lines[:start_index] + [theoretical_code_to_write] + lines[line_index:] + # Here we treat it as a single entry in `lines`. line_index = start_index + 1 if overwrite and len(diffs) > 0: @@ -373,7 +753,7 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[ return diffs -def check_copies(overwrite: bool = False): +def check_copies(overwrite: bool = False, file: str = None): """ Check every file is copy-consistent with the original. Also check the model list in the main README and other READMEs are consistent. @@ -381,14 +761,21 @@ def check_copies(overwrite: bool = False): Args: overwrite (`bool`, *optional*, defaults to `False`): Whether or not to overwrite the copies when they don't match. + file (`bool`, *optional*): + The path to a specific file to check and/or fix. """ - all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True) - all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True) - all_files = list(all_files) + list(all_test_files) + buffer = {} + + if file is None: + all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True) + all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True) + all_files = list(all_files) + list(all_test_files) + else: + all_files = [file] diffs = [] for filename in all_files: - new_diffs = is_copy_consistent(filename, overwrite) + new_diffs = is_copy_consistent(filename, overwrite, buffer) diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs] if not overwrite and len(diffs) > 0: diff = "\n".join(diffs) @@ -733,9 +1120,10 @@ def check_readme(overwrite: bool = False): if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("--file", type=str, default=None, help="A specific file to check and/or fix") parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") args = parser.parse_args() check_readme(args.fix_and_overwrite) - check_copies(args.fix_and_overwrite) + check_copies(args.fix_and_overwrite, args.file) check_full_copies(args.fix_and_overwrite)