From d0b83fe2e1d3df2521908c9a32f51c8742a29c53 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Thu, 6 Apr 2023 18:08:14 +0200 Subject: [PATCH] A script to add/update `pipeline_model_mapping` systematically (#22180) * Auto. add and update pipeline_model_mapping * Fix style and quality * Finalize (comments) * fix --------- Co-authored-by: ydshieh --- utils/add_pipeline_model_mapping_to_test.py | 337 ++++++++++++++++++++ 1 file changed, 337 insertions(+) create mode 100644 utils/add_pipeline_model_mapping_to_test.py diff --git a/utils/add_pipeline_model_mapping_to_test.py b/utils/add_pipeline_model_mapping_to_test.py new file mode 100644 index 0000000000..ebefcff9af --- /dev/null +++ b/utils/add_pipeline_model_mapping_to_test.py @@ -0,0 +1,337 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A script to add and/or update the attribute `pipeline_model_mapping` in model test files. + +This script will be (mostly) used in the following 2 situations: + + - run within a (scheduled) CI job to: + - check if model test files in the library have updated `pipeline_model_mapping`, + - and/or update test files and (possibly) open a GitHub pull request automatically + - being run by a `transformers` member to quickly check and update some particular test file(s) + +This script is **NOT** intended to be run (manually) by community contributors. +""" + + +import argparse +import glob +import inspect +import os +import re +import unittest + +from get_test_info import get_test_classes + +from tests.test_pipeline_mixin import pipeline_test_mapping + + +PIPELINE_TEST_MAPPING = {} +for task, _ in pipeline_test_mapping.items(): + PIPELINE_TEST_MAPPING[task] = {"pt": None, "tf": None} + + +# DO **NOT** add item to this set (unless the reason is approved) +TEST_FILE_TO_IGNORE = { + "tests/models/esm/test_modeling_esmfold.py", # The pipeline test mapping is added to `test_modeling_esm.py` +} + + +def get_framework(test_class): + """Infer the framework from the test class `test_class`.""" + + if "ModelTesterMixin" in [x.__name__ for x in test_class.__bases__]: + return "pt" + elif "TFModelTesterMixin" in [x.__name__ for x in test_class.__bases__]: + return "tf" + elif "FlaxModelTesterMixin" in [x.__name__ for x in test_class.__bases__]: + return "flax" + else: + return None + + +def get_mapping_for_task(task, framework): + """Get mappings defined in `XXXPipelineTests` for the task `task`.""" + # Use the cached results + if PIPELINE_TEST_MAPPING[task].get(framework, None) is not None: + return PIPELINE_TEST_MAPPING[task][framework] + + pipeline_test_class = pipeline_test_mapping[task]["test"] + mapping = None + + if framework == "pt": + mapping = getattr(pipeline_test_class, "model_mapping", None) + elif framework == "tf": + mapping = getattr(pipeline_test_class, "tf_model_mapping", None) + + if mapping is not None: + mapping = dict(mapping.items()) + + # cache the results + PIPELINE_TEST_MAPPING[task][framework] = mapping + return mapping + + +def get_model_for_pipeline_test(test_class, task): + """Get the model architecture(s) related to the test class `test_class` for a pipeline `task`.""" + framework = get_framework(test_class) + if framework is None: + return None + mapping = get_mapping_for_task(task, framework) + if mapping is None: + return None + + config_classes = list({model_class.config_class for model_class in test_class.all_model_classes}) + if len(config_classes) != 1: + raise ValueError("There should be exactly one configuration class from `test_class.all_model_classes`.") + + # This could be a list/tuple of model classes, but it's rare. + model_class = mapping.get(config_classes[0], None) + if isinstance(model_class, (tuple, list)): + model_class = sorted(model_class, key=lambda x: x.__name__) + + return model_class + + +def get_pipeline_model_mapping(test_class): + """Get `pipeline_model_mapping` for `test_class`.""" + mapping = [(task, get_model_for_pipeline_test(test_class, task)) for task in pipeline_test_mapping] + mapping = sorted([(task, model) for task, model in mapping if model is not None], key=lambda x: x[0]) + + return dict(mapping) + + +def get_pipeline_model_mapping_string(test_class): + """Get `pipeline_model_mapping` for `test_class` as a string (to be added to the test file). + + This will be a 1-line string. After this is added to a test file, `make style` will format it beautifully. + """ + framework = get_framework(test_class) + if framework == "pt": + framework = "torch" + default_value = "{}" + + mapping = get_pipeline_model_mapping(test_class) + if len(mapping) == 0: + return "" + + texts = [] + for task, model_classes in mapping.items(): + if isinstance(model_classes, (tuple, list)): + # A list/tuple of model classes + value = "(" + ", ".join([x.__name__ for x in model_classes]) + ")" + else: + # A single model class + value = model_classes.__name__ + texts.append(f'"{task}": {value}') + text = "{" + ", ".join(texts) + "}" + text = f"pipeline_model_mapping = {text} if is_{framework}_available() else {default_value}" + + return text + + +def is_valid_test_class(test_class): + """Restrict to `XXXModelTesterMixin` and should be a subclass of `unittest.TestCase`.""" + base_class_names = {"ModelTesterMixin", "TFModelTesterMixin", "FlaxModelTesterMixin"} + if not issubclass(test_class, unittest.TestCase): + return False + return len(base_class_names.intersection([x.__name__ for x in test_class.__bases__])) > 0 + + +def find_test_class(test_file): + """Find a test class in `test_file` to which we will add `pipeline_model_mapping`.""" + test_classes = [x for x in get_test_classes(test_file) if is_valid_test_class(x)] + + target_test_class = None + for test_class in test_classes: + # If a test class has defined `pipeline_model_mapping`, let's take it + if getattr(test_class, "pipeline_model_mapping", None) is not None: + target_test_class = test_class + break + # Take the test class with the shortest name (just a heuristic) + if target_test_class is None and len(test_classes) > 0: + target_test_class = sorted(test_classes, key=lambda x: (len(x.__name__), x.__name__))[0] + + return target_test_class + + +def find_block_ending(lines, start_idx, indent_level): + end_idx = start_idx + for idx, line in enumerate(lines[start_idx:]): + indent = len(line) - len(line.lstrip()) + if idx == 0 or indent > indent_level or (indent == indent_level and line.strip() == ")"): + end_idx = start_idx + idx + elif idx > 0 and indent <= indent_level: + # Outside the definition block of `pipeline_model_mapping` + break + + return end_idx + + +def add_pipeline_model_mapping(test_class, overwrite=False): + """Add `pipeline_model_mapping` to `test_class`.""" + if getattr(test_class, "pipeline_model_mapping", None) is not None: + if not overwrite: + return "", -1 + + line_to_add = get_pipeline_model_mapping_string(test_class) + if len(line_to_add) == 0: + return "", -1 + line_to_add = line_to_add + "\n" + + # The code defined the class `test_class` + class_lines, class_start_line_no = inspect.getsourcelines(test_class) + # `inspect` gives the code for an object, including decorator(s) if any. + # We (only) need the exact line of the class definition. + for idx, line in enumerate(class_lines): + if line.lstrip().startswith("class "): + class_lines = class_lines[idx:] + class_start_line_no += idx + break + class_end_line_no = class_start_line_no + len(class_lines) - 1 + + # The index in `class_lines` that starts the definition of `all_model_classes`, `all_generative_model_classes` or + # `pipeline_model_mapping`. This assumes they are defined in such order, and we take the start index of the last + # block that appears in a `test_class`. + start_idx = None + # The indent level of the line at `class_lines[start_idx]` (if defined) + indent_level = 0 + # To record if `pipeline_model_mapping` is found in `test_class`. + def_line = None + for idx, line in enumerate(class_lines): + if line.strip().startswith("all_model_classes = "): + indent_level = len(line) - len(line.lstrip()) + start_idx = idx + elif line.strip().startswith("all_generative_model_classes = "): + indent_level = len(line) - len(line.lstrip()) + start_idx = idx + elif line.strip().startswith("pipeline_model_mapping = "): + indent_level = len(line) - len(line.lstrip()) + start_idx = idx + def_line = line + break + + if start_idx is None: + return "", -1 + # Find the ending index (inclusive) of the above found block. + end_idx = find_block_ending(class_lines, start_idx, indent_level) + + # Extract `is_xxx_available()` from existing blocks: some models require specific libraries like `timm` and use + # `is_timm_available()` instead of `is_torch_available()`. + # Keep leading and trailing whitespaces + r = re.compile(r"\s(is_\S+?_available\(\))\s") + for line in class_lines[start_idx : end_idx + 1]: + backend_condition = r.search(line) + if backend_condition is not None: + # replace the leading and trailing whitespaces to the space character " ". + target = " " + backend_condition[0][1:-1] + " " + line_to_add = r.sub(target, line_to_add) + break + + if def_line is None: + # `pipeline_model_mapping` is not defined. The target index is set to the ending index (inclusive) of + # `all_model_classes` or `all_generative_model_classes`. + target_idx = end_idx + else: + # `pipeline_model_mapping` is defined. The target index is set to be one **BEFORE** its start index. + target_idx = start_idx - 1 + # mark the lines of the currently existing `pipeline_model_mapping` to be removed. + for idx in range(start_idx, end_idx + 1): + # These lines are going to be removed before writing to the test file. + class_lines[idx] = None # noqa + + # Make sure the test class is a subclass of `PipelineTesterMixin`. + parent_classes = [x.__name__ for x in test_class.__bases__] + if "PipelineTesterMixin" not in parent_classes: + # Put `PipelineTesterMixin` just before `unittest.TestCase` + _parent_classes = [x for x in parent_classes if x != "TestCase"] + ["PipelineTesterMixin"] + if "TestCase" in parent_classes: + # Here we **assume** the original string is always with `unittest.TestCase`. + _parent_classes.append("unittest.TestCase") + parent_classes = ", ".join(_parent_classes) + for idx, line in enumerate(class_lines): + # Find the ending of the declaration of `test_class` + if line.strip().endswith("):"): + # mark the lines of the declaration of `test_class` to be removed + for _idx in range(idx + 1): + class_lines[_idx] = None # noqa + break + # Add the new, one-line, class declaration for `test_class` + class_lines[0] = f"class {test_class.__name__}({parent_classes}):\n" + + # Add indentation + line_to_add = " " * indent_level + line_to_add + # Insert `pipeline_model_mapping` to `class_lines`. + # (The line at `target_idx` should be kept by definition!) + class_lines = class_lines[: target_idx + 1] + [line_to_add] + class_lines[target_idx + 1 :] + # Remove the lines that are marked to be removed + class_lines = [x for x in class_lines if x is not None] + + # Move from test class to module (in order to write to the test file) + module_lines = inspect.getsourcelines(inspect.getmodule(test_class))[0] + # Be careful with the 1-off between line numbers and array indices + module_lines = module_lines[: class_start_line_no - 1] + class_lines + module_lines[class_end_line_no:] + code = "".join(module_lines) + + moddule_file = inspect.getsourcefile(test_class) + with open(moddule_file, "w", encoding="UTF-8", newline="\n") as fp: + fp.write(code) + + return line_to_add + + +def add_pipeline_model_mapping_to_test_file(test_file, overwrite=False): + """Add `pipeline_model_mapping` to `test_file`.""" + test_class = find_test_class(test_file) + if test_class: + add_pipeline_model_mapping(test_class, overwrite=overwrite) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--test_file", type=str, help="A path to the test file, starting with the repository's `tests` directory." + ) + parser.add_argument( + "--all", + action="store_true", + help="If to check and modify all test files.", + ) + parser.add_argument( + "--overwrite", + action="store_true", + help="If to overwrite a test class if it has already defined `pipeline_model_mapping`.", + ) + args = parser.parse_args() + + if not args.all and not args.test_file: + raise ValueError("Please specify either `test_file` or pass `--all` to check/modify all test files.") + elif args.all and args.test_file: + raise ValueError("Only one of `--test_file` and `--all` could be specified.") + + test_files = [] + if args.test_file: + test_files = [args.test_file] + else: + pattern = os.path.join("tests", "models", "**", "test_modeling_*.py") + for test_file in glob.glob(pattern): + # `Flax` is not concerned at this moment + if not test_file.startswith("test_modeling_flax_"): + test_files.append(test_file) + + for test_file in test_files: + if test_file in TEST_FILE_TO_IGNORE: + print(f"[SKIPPED] {test_file} is skipped as it is in `TEST_FILE_TO_IGNORE` in the file {__file__}.") + continue + add_pipeline_model_mapping_to_test_file(test_file, overwrite=args.overwrite)