Allow `# Ignore copy` (#27328)
* fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
44b5506d29
commit
52746922b0
|
@ -28,7 +28,9 @@ from ...test_tokenization_common import TokenizerTesterMixin
|
|||
|
||||
|
||||
@require_tokenizers
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest with roberta-base->allenai/longformer-base-4096,Roberta->Longformer,roberta->longformer,
|
||||
class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
# Ignore copy
|
||||
tokenizer_class = LongformerTokenizer
|
||||
test_slow_tokenizer = True
|
||||
rust_tokenizer_class = LongformerTokenizerFast
|
||||
|
@ -71,23 +73,19 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
with open(self.merges_file, "w", encoding="utf-8") as fp:
|
||||
fp.write("\n".join(merges))
|
||||
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_tokenizer
|
||||
def get_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_rust_tokenizer
|
||||
def get_rust_tokenizer(self, **kwargs):
|
||||
kwargs.update(self.special_tokens_map)
|
||||
return self.rust_tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
|
||||
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.get_input_output_texts
|
||||
def get_input_output_texts(self, tokenizer):
|
||||
input_text = "lower newer"
|
||||
output_text = "lower newer"
|
||||
return input_text, output_text
|
||||
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_full_tokenizer
|
||||
def test_full_tokenizer(self):
|
||||
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
|
||||
text = "lower newer"
|
||||
|
@ -99,7 +97,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
input_bpe_tokens = [0, 1, 2, 15, 10, 9, 3, 2, 15, 19]
|
||||
self.assertListEqual(tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.roberta_dict_integration_testing with roberta->longformer
|
||||
def longformer_dict_integration_testing(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
|
@ -110,7 +107,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
)
|
||||
|
||||
@slow
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_sequence_builders with roberta-base->allenai/longformer-base-4096
|
||||
def test_sequence_builders(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("allenai/longformer-base-4096")
|
||||
|
||||
|
@ -130,7 +126,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
assert encoded_sentence == encoded_text_from_decode
|
||||
assert encoded_pair == encoded_pair_from_decode
|
||||
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_space_encoding
|
||||
def test_space_encoding(self):
|
||||
tokenizer = self.get_tokenizer()
|
||||
|
||||
|
@ -171,11 +166,9 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
first_char = tokenizer.convert_ids_to_tokens(encoded[mask_loc + 1])[0]
|
||||
self.assertNotEqual(first_char, space_encoding)
|
||||
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_pretokenized_inputs
|
||||
def test_pretokenized_inputs(self):
|
||||
pass
|
||||
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_embeded_special_tokens
|
||||
def test_embeded_special_tokens(self):
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
|
@ -208,7 +201,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
tokens_r_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
|
||||
)
|
||||
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_change_add_prefix_space_and_trim_offsets_args
|
||||
def test_change_add_prefix_space_and_trim_offsets_args(self):
|
||||
for trim_offsets, add_prefix_space in itertools.product([True, False], repeat=2):
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
|
@ -223,7 +215,6 @@ class LongformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||
self.assertEqual(post_processor_state["add_prefix_space"], add_prefix_space)
|
||||
self.assertEqual(post_processor_state["trim_offsets"], trim_offsets)
|
||||
|
||||
# Copied from tests.models.roberta.test_tokenization_roberta.RobertaTokenizationTest.test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments
|
||||
def test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments(self):
|
||||
# Test which aims to verify that the offsets are well adapted to the argument `add_prefix_space` and
|
||||
# `trim_offsets`
|
||||
|
|
|
@ -95,13 +95,156 @@ class BertCopyModel(BertCopyPreTrainedModel):
|
|||
"""
|
||||
|
||||
|
||||
MOCK_DUMMY_BERT_CODE_MATCH = """
|
||||
class BertDummyModel:
|
||||
attr_1 = 1
|
||||
attr_2 = 2
|
||||
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||
def forward(self, c):
|
||||
return 1
|
||||
|
||||
def existing_common(self, c):
|
||||
return 4
|
||||
|
||||
def existing_diff_to_be_ignored(self, c):
|
||||
return 9
|
||||
"""
|
||||
|
||||
|
||||
MOCK_DUMMY_ROBERTA_CODE_MATCH = """
|
||||
# Copied from transformers.models.dummy_bert_match.modeling_dummy_bert_match.BertDummyModel with BertDummy->RobertaBertDummy
|
||||
class RobertaBertDummyModel:
|
||||
|
||||
attr_1 = 1
|
||||
attr_2 = 2
|
||||
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
# Ignore copy
|
||||
def only_in_roberta_to_be_ignored(self, c):
|
||||
return 3
|
||||
|
||||
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||
def forward(self, c):
|
||||
return 1
|
||||
|
||||
def existing_common(self, c):
|
||||
return 4
|
||||
|
||||
# Ignore copy
|
||||
def existing_diff_to_be_ignored(self, c):
|
||||
return 6
|
||||
"""
|
||||
|
||||
|
||||
MOCK_DUMMY_BERT_CODE_NO_MATCH = """
|
||||
class BertDummyModel:
|
||||
attr_1 = 1
|
||||
attr_2 = 2
|
||||
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||
def forward(self, c):
|
||||
return 1
|
||||
|
||||
def only_in_bert(self, c):
|
||||
return 7
|
||||
|
||||
def existing_common(self, c):
|
||||
return 4
|
||||
|
||||
def existing_diff_not_ignored(self, c):
|
||||
return 8
|
||||
|
||||
def existing_diff_to_be_ignored(self, c):
|
||||
return 9
|
||||
"""
|
||||
|
||||
|
||||
MOCK_DUMMY_ROBERTA_CODE_NO_MATCH = """
|
||||
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
|
||||
class RobertaBertDummyModel:
|
||||
|
||||
attr_1 = 1
|
||||
attr_2 = 3
|
||||
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
# Ignore copy
|
||||
def only_in_roberta_to_be_ignored(self, c):
|
||||
return 3
|
||||
|
||||
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||
def forward(self, c):
|
||||
return 1
|
||||
|
||||
def only_in_roberta_not_ignored(self, c):
|
||||
return 2
|
||||
|
||||
def existing_common(self, c):
|
||||
return 4
|
||||
|
||||
def existing_diff_not_ignored(self, c):
|
||||
return 5
|
||||
|
||||
# Ignore copy
|
||||
def existing_diff_to_be_ignored(self, c):
|
||||
return 6
|
||||
"""
|
||||
|
||||
|
||||
EXPECTED_REPLACED_CODE = """
|
||||
# Copied from transformers.models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel with BertDummy->RobertaBertDummy
|
||||
class RobertaBertDummyModel:
|
||||
attr_1 = 1
|
||||
attr_2 = 2
|
||||
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
# Copied from transformers.models.dummy_gpt2.modeling_dummy_gpt2.GPT2DummyModel.forward
|
||||
def forward(self, c):
|
||||
return 1
|
||||
|
||||
def only_in_bert(self, c):
|
||||
return 7
|
||||
|
||||
def existing_common(self, c):
|
||||
return 4
|
||||
|
||||
def existing_diff_not_ignored(self, c):
|
||||
return 8
|
||||
|
||||
# Ignore copy
|
||||
def existing_diff_to_be_ignored(self, c):
|
||||
return 6
|
||||
|
||||
# Ignore copy
|
||||
def only_in_roberta_to_be_ignored(self, c):
|
||||
return 3
|
||||
"""
|
||||
|
||||
|
||||
def replace_in_file(filename, old, new):
    """Replace every occurrence of `old` by `new` in the file `filename`, in place.

    Args:
        filename (`str` or `os.PathLike`): The file to rewrite. It is read and written as UTF-8.
        old (`str`): The text to search for.
        new (`str`): The text that replaces each occurrence of `old`.
    """
    with open(filename, "r", encoding="utf-8") as f:
        content = f.read()

    content = content.replace(old, new)

    # Open the file for writing exactly once. `newline` is pinned so that line endings are written as LF on
    # every platform instead of being translated to the OS default.
    with open(filename, "w", encoding="utf-8", newline="\n") as f:
        f.write(content)
|
||||
|
||||
|
||||
|
@ -117,11 +260,18 @@ def create_tmp_repo(tmp_dir):
|
|||
model_dir = tmp_dir / "src" / "transformers" / "models"
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
models = {"bert": MOCK_BERT_CODE, "bertcopy": MOCK_BERT_COPY_CODE}
|
||||
models = {
|
||||
"bert": MOCK_BERT_CODE,
|
||||
"bertcopy": MOCK_BERT_COPY_CODE,
|
||||
"dummy_bert_match": MOCK_DUMMY_BERT_CODE_MATCH,
|
||||
"dummy_roberta_match": MOCK_DUMMY_ROBERTA_CODE_MATCH,
|
||||
"dummy_bert_no_match": MOCK_DUMMY_BERT_CODE_NO_MATCH,
|
||||
"dummy_roberta_no_match": MOCK_DUMMY_ROBERTA_CODE_NO_MATCH,
|
||||
}
|
||||
for model, code in models.items():
|
||||
model_subdir = model_dir / model
|
||||
model_subdir.mkdir(exist_ok=True)
|
||||
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8") as f:
|
||||
with open(model_subdir / f"modeling_{model}.py", "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(code)
|
||||
|
||||
|
||||
|
@ -176,11 +326,47 @@ class CopyCheckTester(unittest.TestCase):
|
|||
diffs = is_copy_consistent(file_to_check)
|
||||
self.assertEqual(diffs, [["models.bert.modeling_bert.BertModel", 22]])
|
||||
|
||||
diffs = is_copy_consistent(file_to_check, overwrite=True)
|
||||
_ = is_copy_consistent(file_to_check, overwrite=True)
|
||||
|
||||
with open(file_to_check, "r", encoding="utf-8") as f:
|
||||
self.assertEqual(f.read(), MOCK_BERT_COPY_CODE)
|
||||
|
||||
def test_is_copy_consistent_with_ignored_match(self):
    """Check that `is_copy_consistent` reports no diff for the `dummy_roberta_match` mock file.

    That mock file contains `# Ignore copy` blocks; the expected result of the check is an empty list.
    """
    path_to_check = ["src", "transformers", "models", "dummy_roberta_match", "modeling_dummy_roberta_match.py"]
    with tempfile.TemporaryDirectory() as tmp_folder:
        # Base check
        create_tmp_repo(tmp_folder)
        with patch_transformer_repo_path(tmp_folder):
            file_to_check = os.path.join(tmp_folder, *path_to_check)
            diffs = is_copy_consistent(file_to_check)
            self.assertEqual(diffs, [])
|
||||
|
||||
def test_is_copy_consistent_with_ignored_no_match(self):
    """Check `is_copy_consistent` on the `dummy_roberta_no_match` mock file.

    The first call must report a single diff at line 6. The second call (with `overwrite=True`) must rewrite
    the file so that its content equals `EXPECTED_REPLACED_CODE`.
    """
    path_to_check = [
        "src",
        "transformers",
        "models",
        "dummy_roberta_no_match",
        "modeling_dummy_roberta_no_match.py",
    ]
    with tempfile.TemporaryDirectory() as tmp_folder:
        # Base check with an inconsistency
        create_tmp_repo(tmp_folder)
        with patch_transformer_repo_path(tmp_folder):
            file_to_check = os.path.join(tmp_folder, *path_to_check)

            diffs = is_copy_consistent(file_to_check)
            # line 6: `attr_2 = 3` in `MOCK_DUMMY_ROBERTA_CODE_NO_MATCH`.
            # (which has a leading `\n`.)
            self.assertEqual(
                diffs, [["models.dummy_bert_no_match.modeling_dummy_bert_no_match.BertDummyModel", 6]]
            )

            # The return value is irrelevant here; only the rewritten file content is checked.
            _ = is_copy_consistent(file_to_check, overwrite=True)

            with open(file_to_check, "r", encoding="utf-8") as f:
                self.assertEqual(f.read(), EXPECTED_REPLACED_CODE)
|
||||
|
||||
def test_convert_to_localized_md(self):
|
||||
localized_readme = check_copies.LOCALIZED_READMES["README_zh-hans.md"]
|
||||
|
||||
|
|
|
@ -41,7 +41,8 @@ import glob
|
|||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from typing import List, Optional, Tuple
|
||||
from collections import OrderedDict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from transformers.utils import direct_transformers_import
|
||||
|
||||
|
@ -125,13 +126,213 @@ LOCALIZED_READMES = {
|
|||
transformers_module = direct_transformers_import(TRANSFORMERS_PATH)
|
||||
|
||||
|
||||
def _is_definition_header_ending_line(line: str) -> bool:
|
||||
# Helper function. Returns `True` if `line` is the end parenthesis of a class/function definition
|
||||
return re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
|
||||
|
||||
|
||||
def _should_continue(line: str, indent: str) -> bool:
|
||||
# Helper function. Returns `True` if `line` is empty, starts with the `indent` or is the end parenthesis of a
|
||||
# function definition
|
||||
return line.startswith(indent) or len(line.strip()) == 0 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None
|
||||
# class/function definition
|
||||
return line.startswith(indent) or len(line.strip()) == 0 or _is_definition_header_ending_line(line)
|
||||
|
||||
|
||||
def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
|
||||
def _sanity_check_splits(splits_1, splits_2, is_class):
|
||||
"""Check the two (inner) block structures of the corresponding code block given by `split_code_into_blocks` match.
|
||||
|
||||
For the case of `class`, they must be of one of the following 3 cases:
|
||||
|
||||
- a single block without name:
|
||||
|
||||
class foo:
|
||||
a = 1
|
||||
|
||||
- a consecutive sequence of (1 or more) blocks with name
|
||||
|
||||
class foo:
|
||||
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
- a block without name, followed by a consecutive sequence of (1 or more) blocks with name
|
||||
|
||||
class foo:
|
||||
a = 1
|
||||
|
||||
def f(x):
|
||||
return x
|
||||
|
||||
def g(x):
|
||||
return None
|
||||
|
||||
The 2 code snippets that give `splits_1` and `splits_2` have to be in the same case to pass this check, but the
|
||||
number of blocks with name in the consecutive sequence is not taken into account.
|
||||
|
||||
For the case of `function or method`, we don't require it to be in one of the above 3 cases. However, the structure
|
||||
of`splits_1` and `splits_2` have to match exactly. In particular, the number of blocks with name in a consecutive
|
||||
sequence is taken into account.
|
||||
"""
|
||||
block_names_1 = []
|
||||
block_names_2 = []
|
||||
|
||||
for block in splits_1[1:]:
|
||||
if block[0].startswith("_block_without_name_"):
|
||||
block_names_1.append("block_without_name")
|
||||
elif not block[0].startswith("_empty_block_") and (
|
||||
not is_class or len(block_names_1) == 0 or block_names_1[-1].startswith("block_without_name")
|
||||
):
|
||||
block_names_1.append("block_with_name")
|
||||
|
||||
for block in splits_2[1:]:
|
||||
if block[0].startswith("_block_without_name_"):
|
||||
block_names_2.append("block_without_name")
|
||||
elif not block[0].startswith("_empty_block_") and (
|
||||
not is_class or len(block_names_2) == 0 or block_names_2[-1].startswith("block_without_name")
|
||||
):
|
||||
block_names_2.append("block_with_name")
|
||||
|
||||
if is_class:
|
||||
if block_names_1 not in [
|
||||
["block_without_name"],
|
||||
["block_with_name"],
|
||||
["block_without_name", "block_with_name"],
|
||||
]:
|
||||
raise ValueError(
|
||||
"For a class, it must have a specific structure. See the docstring of `_sanity_check_splits` in the file `utils/check_copies.py`"
|
||||
)
|
||||
|
||||
if block_names_1 != block_names_2:
|
||||
raise ValueError("The structures in the 2 code blocks differ.")
|
||||
|
||||
|
||||
def find_block_end(lines: List[str], start_index: int, indent: int) -> int:
    """
    Find the end of the class/func block starting at `start_index` in a source code (defined by `lines`).

    Args:
        lines (`List[str]`):
            The source code, represented by a list of lines.
        start_index (`int`):
            The starting index of the target class/func block.
        indent (`int`):
            The indent of the class/func body.

    Returns:
        `int`: The index one past the block's last line (i.e. an exclusive end index). Empty lines at the end of
        the block are not part of it. The result is always at least `start_index + 1`.
    """
    indent_str = " " * indent
    # enter the block body
    line_index = start_index + 1

    while line_index < len(lines) and _should_continue(lines[line_index], indent_str):
        line_index += 1
    # Clean up empty lines at the end (if any). The lower bound keeps the loop from stepping back over the
    # definition header (and from wrapping around to negative indices).
    while line_index > start_index + 1 and len(lines[line_index - 1]) <= 1:
        line_index -= 1

    return line_index
|
||||
|
||||
|
||||
def split_code_into_blocks(
    lines: List[str], start_index: int, end_index: int, indent: int, backtrace: bool = False
) -> List[Tuple[str, int, int]]:
    """
    Split the class/func block starting at `start_index` in a source code (defined by `lines`) into *inner blocks*.

    The block's header is included as the first element. The contiguous regions (without empty lines) that are not
    inside any inner block are included as blocks. The contiguous regions of empty lines that are not inside any inner
    block are also included as (dummy) blocks.

    Args:
        lines (`List[str]`):
            The source code, represented by a list of lines.
        start_index (`int`):
            The starting index of the target class/func block.
        end_index (`int`):
            The ending index of the target class/func block.
        indent (`int`):
            The indent of the class/func body.
        backtrace (`bool`, *optional*, defaults to `False`):
            Whether or not to include the lines before the inner class/func block's header (e.g. comments, decorators,
            etc.) until an empty line is encountered.

    Returns:
        `List[Tuple[str, int, int]]`: A list of elements with the form `(block_name, start_index, end_index)`.

    Raises:
        `ValueError`: If `lines[start_index]` is not a `class`/`def` header at indent level `indent - 4`.
    """
    splits = []
    # `indent - 4` is the indent level of the target class/func header
    header_indent = " " * (indent - 4)
    header_match = re.search(rf"^{header_indent}((class|def)\s+\S+)(\(|\:)", lines[start_index])
    if header_match is None:
        raise ValueError(
            f"Line {start_index} is not a class/function definition header at indent {indent - 4}: "
            f"{lines[start_index]!r}"
        )
    target_block_name = header_match.groups()[0]

    # from now on, the `block` means inner blocks unless explicitly specified
    indent_str = " " * indent
    block_without_name_idx = 0
    empty_block_idx = 0

    # Find the lines for the definition header: a header that opens a parenthesis without closing it with `):`
    # on the same line spans several lines, up to its closing-parenthesis line.
    index = start_index
    if "(" in lines[start_index] and "):" not in lines[start_index]:
        while index < end_index:
            if _is_definition_header_ending_line(lines[index]):
                break
            index += 1

    # the first line outside the definition header
    index += 1
    splits.append((target_block_name, start_index, index))

    prev_block_end_index = index
    while index < end_index:
        # if found, it will be an inner block
        block_found = re.search(rf"^{indent_str}((class|def)\s+\S+)(\(|\:)", lines[index])
        if block_found:
            name = block_found.groups()[0]

            block_end_index = find_block_end(lines, index, indent + 4)

            # backtrace to include the lines before the found block's definition header (e.g. comments, decorators,
            # etc.) until an empty line (or a line not at this indent level) is encountered, without going back
            # into the previous block.
            block_start_index = index
            if backtrace:
                while (
                    block_start_index > prev_block_end_index
                    and len(lines[block_start_index - 1].strip()) > 0
                    and lines[block_start_index - 1].startswith(indent_str)
                ):
                    block_start_index -= 1

            # between the current found block and the previous found block
            if block_start_index > prev_block_end_index:
                # give it a dummy name
                if len("".join(lines[prev_block_end_index:block_start_index]).strip()) == 0:
                    prev_block_name = f"_empty_block_{empty_block_idx}"
                    empty_block_idx += 1
                else:
                    prev_block_name = f"_block_without_name_{block_without_name_idx}"
                    block_without_name_idx += 1
                # Add it as a block
                splits.append((prev_block_name, prev_block_end_index, block_start_index))

            # Add the current found block
            splits.append((name, block_start_index, block_end_index))
            prev_block_end_index = block_end_index
            # `index` is incremented below, so the scan resumes at `block_end_index`
            index = block_end_index - 1

        index += 1

    # whatever remains after the last inner block
    if index > prev_block_end_index:
        if len("".join(lines[prev_block_end_index:index]).strip()) == 0:
            prev_block_name = f"_empty_block_{empty_block_idx}"
        else:
            prev_block_name = f"_block_without_name_{block_without_name_idx}"
        splits.append((prev_block_name, prev_block_end_index, index))

    return splits
|
||||
|
||||
|
||||
def find_code_in_transformers(
|
||||
object_name: str, base_path: str = None, return_indices: bool = False
|
||||
) -> Union[str, Tuple[List[str], int, int]]:
|
||||
"""
|
||||
Find and return the source code of an object.
|
||||
|
||||
|
@ -140,9 +341,15 @@ def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
|
|||
The name of the object we want the source code of.
|
||||
base_path (`str`, *optional*):
|
||||
The path to the base folder where files are checked. If not set, it will be set to `TRANSFORMERS_PATH`.
|
||||
return_indices (`bool`, *optional*, defaults to `False`):
|
||||
If `False`, will only return the code (as a string), otherwise it will also return the whole lines of the
|
||||
file where the object specified by `object_name` is defined, together with the start/end indices of the block in
|
||||
the file that defines the object.
|
||||
|
||||
Returns:
|
||||
`str`: The source code of the object.
|
||||
`Union[str, Tuple[List[str], int, int]]`: If `return_indices=False`, only the source code of the object will be
|
||||
returned. Otherwise, it also returns the whole lines of the file where the object specified by `object_name` is
|
||||
defined, together with the start/end indices of the block in the file that defines the object.
|
||||
"""
|
||||
parts = object_name.split(".")
|
||||
i = 0
|
||||
|
@ -181,22 +388,91 @@ def find_code_in_transformers(object_name: str, base_path: str = None) -> str:
|
|||
line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None
|
||||
):
|
||||
line_index += 1
|
||||
# find the target specified in the current level in `parts` -> increase `indent` so we can search the next
|
||||
indent += " "
|
||||
# the index of the first line in the (currently found) block *body*
|
||||
line_index += 1
|
||||
|
||||
if line_index >= len(lines):
|
||||
raise ValueError(f" {object_name} does not match any function or class in {module}.")
|
||||
|
||||
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
|
||||
start_index = line_index - 1
|
||||
while line_index < len(lines) and _should_continue(lines[line_index], indent):
|
||||
line_index += 1
|
||||
# Clean up empty lines at the end (if any).
|
||||
while len(lines[line_index - 1]) <= 1:
|
||||
line_index -= 1
|
||||
# `indent` is already one level deeper than the (found) class/func block's definition header
|
||||
|
||||
code_lines = lines[start_index:line_index]
|
||||
return "".join(code_lines)
|
||||
# We found the beginning of the class / func, now let's find the end (when the indent diminishes).
|
||||
# `start_index` is the index of the class/func block's definition header
|
||||
start_index = line_index - 1
|
||||
end_index = find_block_end(lines, start_index, len(indent))
|
||||
|
||||
code = "".join(lines[start_index:end_index])
|
||||
return (code, (lines, start_index, end_index)) if return_indices else code
|
||||
|
||||
|
||||
def replace_code(code: str, replace_pattern: str) -> str:
    """Replace `code` by a pattern of the form `with X1->X2,Y1->Y2,Z1->Z2`.

    Each comma-separated entry of `replace_pattern` is parsed with `_re_replace_pattern`; entries that do not
    match are skipped. Note that every occurrence of the substring `with` is removed from `replace_pattern`
    before parsing, and that `X1` is passed to `re.sub` as a regular expression, not as a literal string.

    Args:
        code (`str`): The code to be modified.
        replace_pattern (`str`): The pattern used to modify `code`. An empty string leaves `code` unchanged.

    Returns:
        `str`: The modified code.
    """
    if len(replace_pattern) == 0:
        return code

    for raw_pattern in replace_pattern.replace("with", "").split(","):
        pattern = _re_replace_pattern.search(raw_pattern)
        if pattern is None:
            continue
        obj1, obj2, option = pattern.groups()
        code = re.sub(obj1, obj2, code)
        # with the `all-casing` option, the lower-cased and upper-cased variants are replaced as well
        if option.strip() == "all-casing":
            code = re.sub(obj1.lower(), obj2.lower(), code)
            code = re.sub(obj1.upper(), obj2.upper(), code)

    return code
|
||||
|
||||
|
||||
def find_code_and_splits(
    object_name: str, base_path: str, buffer: Optional[dict] = None
) -> Tuple[List[str], str, List[Tuple[str, int, int]]]:
    """Find the code of an object (specified by `object_name`) and split it into blocks.

    Args:
        object_name (`str`):
            The name of the object, e.g. `transformers.models.bert.modeling_bert.BertAttention` or
            `tests.models.llama.test_modeling_llama.LlamaModelTest.test_config`.
        base_path (`str`):
            The path to the base directory within which the search will be performed. It could be either
            `TRANSFORMERS_PATH` or `MODEL_TEST_PATH`.
        buffer (`dict`, *optional*):
            The buffer used to store the previous results in order to speed up the process. It is keyed by
            `(object_name, base_path)` and is updated in place when a new result is computed.

    Returns:
        lines (`List[str]`):
            The lines of the whole file where the object is defined.
        code (`str`):
            The object's code.
        code_splits (`List[Tuple[str, int, int]]`):
            `code` split into blocks. See `split_code_into_blocks`.
    """
    if buffer is None:
        buffer = {}

    key = (object_name, base_path)
    if key in buffer:
        lines, code, code_splits = buffer[key]
    else:
        code, (lines, target_start_index, target_end_index) = find_code_in_transformers(
            object_name, base_path=base_path, return_indices=True
        )
        indent = get_indent(code)

        # Split the code into blocks
        # `indent` is the indent of the class/func definition header, but `code_splits` expects the indent level of the
        # block body.
        code_splits = split_code_into_blocks(
            lines, target_start_index, target_end_index, len(indent) + 4, backtrace=True
        )
        buffer[key] = lines, code, code_splits

    return lines, code, code_splits
|
||||
|
||||
|
||||
_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+transformers\.(\S+\.\S+)\s*($|\S.*$)")
|
||||
|
@ -285,7 +561,7 @@ def check_codes_match(observed_code: str, theoretical_code: str) -> Optional[int
|
|||
diff_index += 1
|
||||
|
||||
|
||||
def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[Tuple[str, int]]]:
|
||||
def is_copy_consistent(filename: str, overwrite: bool = False, buffer: dict = None) -> Optional[List[Tuple[str, int]]]:
|
||||
"""
|
||||
Check if the code commented as a copy in a file matches the original.
|
||||
|
||||
|
@ -294,11 +570,15 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
|
|||
The name of the file to check.
|
||||
overwrite (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to overwrite the copies when they don't match.
|
||||
buffer (`dict`, *optional*):
|
||||
The buffer used to store the previous results in order to speed up the process.
|
||||
|
||||
Returns:
|
||||
`Optional[List[Tuple[str, int]]]`: If `overwrite=False`, returns the list of differences as tuples `(str, int)`
|
||||
with the name of the object having a diff and the line number where there is the first diff.
|
||||
"""
|
||||
base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH
|
||||
|
||||
with open(filename, "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
diffs = []
|
||||
|
@ -317,16 +597,31 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
|
|||
# There is some copied code here, let's retrieve the original.
|
||||
indent, object_name, replace_pattern = search.groups()
|
||||
|
||||
base_path = TRANSFORMERS_PATH if not filename.startswith("tests") else MODEL_TEST_PATH
|
||||
theoretical_code = find_code_in_transformers(object_name, base_path=base_path)
|
||||
# Find the file lines, the object's code, and its blocks
|
||||
target_lines, theoretical_code, theoretical_code_splits = find_code_and_splits(
|
||||
object_name, base_path, buffer=buffer
|
||||
)
|
||||
|
||||
# code replaced by the patterns
|
||||
theoretical_code_blocks = OrderedDict()
|
||||
for name, start, end in theoretical_code_splits:
|
||||
name = replace_code(name, replace_pattern)
|
||||
code = "".join(target_lines[start:end])
|
||||
code = replace_code(code, replace_pattern)
|
||||
theoretical_code_blocks[name] = code
|
||||
|
||||
theoretical_indent = get_indent(theoretical_code)
|
||||
|
||||
# `start_index` is the index of the first line (the definition header) after `# Copied from`.
|
||||
# (`indent != theoretical_indent` doesn't seem to occur so far, not sure what this case is for.)
|
||||
start_index = line_index + 1 if indent == theoretical_indent else line_index
|
||||
# enter the block body
|
||||
line_index = start_index + 1
|
||||
|
||||
subcode = "\n".join(theoretical_code.split("\n")[1:])
|
||||
indent = get_indent(subcode)
|
||||
# Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment.
|
||||
# We can't call `find_block_end` directly as there is something special (`# End copy`) here.
|
||||
should_continue = True
|
||||
while line_index < len(lines) and should_continue:
|
||||
line_index += 1
|
||||
|
@ -336,33 +631,118 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
|
|||
# There is a special pattern `# End copy` to stop early. It's not documented because it shouldn't really be
|
||||
# used.
|
||||
should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None
|
||||
# `line_index` is outside the block
|
||||
# Clean up empty lines at the end (if any).
|
||||
while len(lines[line_index - 1]) <= 1:
|
||||
line_index -= 1
|
||||
|
||||
observed_code_lines = lines[start_index:line_index]
|
||||
observed_code = "".join(observed_code_lines)
|
||||
# Split the observed code into blocks
|
||||
observed_code_splits = split_code_into_blocks(lines, start_index, line_index, len(indent), backtrace=True)
|
||||
|
||||
# Before comparing, use the `replace_pattern` on the original code.
|
||||
if len(replace_pattern) > 0:
|
||||
patterns = replace_pattern.replace("with", "").split(",")
|
||||
patterns = [_re_replace_pattern.search(p) for p in patterns]
|
||||
for pattern in patterns:
|
||||
if pattern is None:
|
||||
continue
|
||||
obj1, obj2, option = pattern.groups()
|
||||
theoretical_code = re.sub(obj1, obj2, theoretical_code)
|
||||
if option.strip() == "all-casing":
|
||||
theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code)
|
||||
theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code)
|
||||
is_class = lines[start_index].startswith(f"{' ' * (len(indent) - 4)}class ")
|
||||
# sanity check
|
||||
_sanity_check_splits(theoretical_code_splits, observed_code_splits, is_class=is_class)
|
||||
|
||||
# observed code in a structured way (a dict mapping block names to blocks' code)
|
||||
observed_code_blocks = OrderedDict()
|
||||
for name, start, end in observed_code_splits:
|
||||
code = "".join(lines[start:end])
|
||||
observed_code_blocks[name] = code
|
||||
|
||||
# Below, we change some names in `theoretical_code_blocks` and `observed_code_blocks`. These mappings map the
|
||||
# original names to the modified names: this is used to restore the original order of the code blocks.
|
||||
name_mappings_1 = {k: k for k in theoretical_code_blocks.keys()}
|
||||
name_mappings_2 = {k: k for k in observed_code_blocks.keys()}
|
||||
|
||||
# Update code blocks' name and content:
|
||||
# If `"# Ignore copy"` is found in a block of the observed code:
|
||||
# 1. if it's a block only in the observed code --> add it to the theoretical code.
|
||||
# 2. if it's also in the theoretical code --> put its content (body) into the corresponding block under the
|
||||
# same name in the theoretical code.
|
||||
# In both cases, we change the name to have a prefix `_ignored_` so we know if we can discard them during the
|
||||
# comparison.
|
||||
ignored_existing_block_index = 0
|
||||
ignored_new_block_index = 0
|
||||
for name in list(observed_code_blocks.keys()):
|
||||
code = observed_code_blocks[name]
|
||||
if "# Ignore copy" in code:
|
||||
if name in theoretical_code_blocks:
|
||||
# in the target --> just copy the content
|
||||
del theoretical_code_blocks[name]
|
||||
theoretical_code_blocks[f"_ignored_existing_block_{ignored_existing_block_index}"] = code
|
||||
name_mappings_1[name] = f"_ignored_existing_block_{ignored_existing_block_index}"
|
||||
|
||||
del observed_code_blocks[name]
|
||||
observed_code_blocks[f"_ignored_existing_block_{ignored_existing_block_index}"] = code
|
||||
name_mappings_2[name] = f"_ignored_existing_block_{ignored_existing_block_index}"
|
||||
ignored_existing_block_index += 1
|
||||
else:
|
||||
# not in the target --> add it
|
||||
theoretical_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code
|
||||
name_mappings_1[
|
||||
f"_ignored_new_block_{ignored_new_block_index}"
|
||||
] = f"_ignored_new_block_{ignored_new_block_index}"
|
||||
|
||||
del observed_code_blocks[name]
|
||||
observed_code_blocks[f"_ignored_new_block_{ignored_new_block_index}"] = code
|
||||
name_mappings_2[name] = f"_ignored_new_block_{ignored_new_block_index}"
|
||||
ignored_new_block_index += 1
|
||||
|
||||
# Respect the original block order:
|
||||
# 1. in `theoretical_code_blocks`: the new blocks will follow the existing ones
|
||||
# 2. in `observed_code_blocks`: the original order is kept, with names potentially modified. This is necessary
|
||||
# to compute the correct `diff_index` if `overwrite=True` and there is a diff.
|
||||
theoretical_code_blocks = {
|
||||
name_mappings_1[orig_name]: theoretical_code_blocks[name_mappings_1[orig_name]]
|
||||
for orig_name in name_mappings_1
|
||||
}
|
||||
observed_code_blocks = {
|
||||
name_mappings_2[orig_name]: observed_code_blocks[name_mappings_2[orig_name]]
|
||||
for orig_name in name_mappings_2
|
||||
}
|
||||
|
||||
# Ignore the blocks specified to be ignored. This is the version used to check if there is a mismatch
|
||||
theoretical_code_blocks_clean = {
|
||||
k: v
|
||||
for k, v in theoretical_code_blocks.items()
|
||||
if not (k.startswith(("_ignored_existing_block_", "_ignored_new_block_")))
|
||||
}
|
||||
theoretical_code = "".join(list(theoretical_code_blocks_clean.values()))
|
||||
|
||||
# stylify `theoretical_code` before compare (this is needed only when `replace_pattern` is not empty)
|
||||
if replace_pattern:
|
||||
theoretical_code = stylify(theoretical_code)
|
||||
# Remove `\n\n` in `theoretical_code` before compare (so no empty line)
|
||||
while "\n\n" in theoretical_code:
|
||||
theoretical_code = theoretical_code.replace("\n\n", "\n")
|
||||
|
||||
# Compute `observed_code` where we don't include any empty line + keep track of the line index between the
|
||||
# original/processed `observed_code` so we can have the correct `diff_index`.
|
||||
idx_to_orig_idx_mapping_for_observed_code_lines = {}
|
||||
idx = -1
|
||||
orig_idx = -1
|
||||
observed_code = ""
|
||||
for name, code in observed_code_blocks.items():
|
||||
if code.endswith("\n"):
|
||||
code = code[:-1]
|
||||
for code_line in code.split("\n"):
|
||||
orig_idx += 1
|
||||
if code_line.strip() and not name.startswith(("_ignored_existing_block_", "_ignored_new_block_")):
|
||||
idx += 1
|
||||
observed_code += code_line + "\n"
|
||||
idx_to_orig_idx_mapping_for_observed_code_lines[idx] = orig_idx
|
||||
|
||||
# Test for a diff and act accordingly.
|
||||
diff_index = check_codes_match(observed_code, theoretical_code)
|
||||
if diff_index is not None:
|
||||
# switch to the index in the original `observed_code` (i.e. before removing empty lines)
|
||||
diff_index = idx_to_orig_idx_mapping_for_observed_code_lines[diff_index]
|
||||
diffs.append([object_name, diff_index + start_index + 1])
|
||||
if overwrite:
|
||||
lines = lines[:start_index] + [theoretical_code] + lines[line_index:]
|
||||
# `theoretical_code_to_write` is a single string but may have several lines.
|
||||
theoretical_code_to_write = stylify("".join(list(theoretical_code_blocks.values())))
|
||||
lines = lines[:start_index] + [theoretical_code_to_write] + lines[line_index:]
|
||||
# Here we treat it as a single entry in `lines`.
|
||||
line_index = start_index + 1
|
||||
|
||||
if overwrite and len(diffs) > 0:
|
||||
|
@ -373,7 +753,7 @@ def is_copy_consistent(filename: str, overwrite: bool = False) -> Optional[List[
|
|||
return diffs
|
||||
|
||||
|
||||
def check_copies(overwrite: bool = False):
|
||||
def check_copies(overwrite: bool = False, file: str = None):
|
||||
"""
|
||||
Check every file is copy-consistent with the original. Also check the model list in the main README and other
|
||||
READMEs are consistent.
|
||||
|
@ -381,14 +761,21 @@ def check_copies(overwrite: bool = False):
|
|||
Args:
|
||||
overwrite (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to overwrite the copies when they don't match.
|
||||
file (`str`, *optional*):
|
||||
The path to a specific file to check and/or fix.
|
||||
"""
|
||||
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
|
||||
all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True)
|
||||
all_files = list(all_files) + list(all_test_files)
|
||||
buffer = {}
|
||||
|
||||
if file is None:
|
||||
all_files = glob.glob(os.path.join(TRANSFORMERS_PATH, "**/*.py"), recursive=True)
|
||||
all_test_files = glob.glob(os.path.join(MODEL_TEST_PATH, "**/*.py"), recursive=True)
|
||||
all_files = list(all_files) + list(all_test_files)
|
||||
else:
|
||||
all_files = [file]
|
||||
|
||||
diffs = []
|
||||
for filename in all_files:
|
||||
new_diffs = is_copy_consistent(filename, overwrite)
|
||||
new_diffs = is_copy_consistent(filename, overwrite, buffer)
|
||||
diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs]
|
||||
if not overwrite and len(diffs) > 0:
|
||||
diff = "\n".join(diffs)
|
||||
|
@ -733,9 +1120,10 @@ def check_readme(overwrite: bool = False):
|
|||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--file", type=str, default=None, help="A specific file to check and/or fix")
|
||||
parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.")
|
||||
args = parser.parse_args()
|
||||
|
||||
check_readme(args.fix_and_overwrite)
|
||||
check_copies(args.fix_and_overwrite)
|
||||
check_copies(args.fix_and_overwrite, args.file)
|
||||
check_full_copies(args.fix_and_overwrite)
|
||||
|
|
Loading…
Reference in New Issue