Automate check for new pipelines and metadata update (#19029)
* Automate check for new pipelines and metadata update * Add Datasets to quality extra
This commit is contained in:
parent
0efbb6e93e
commit
3774010161
|
@ -982,6 +982,7 @@ jobs:
|
|||
- run: python utils/check_config_docstrings.py
|
||||
- run: make deps_table_check_updated
|
||||
- run: python utils/tests_fetcher.py --sanity_check
|
||||
- run: python utils/update_metadata.py --check-only
|
||||
|
||||
run_tests_layoutlmv2_and_v3:
|
||||
working_directory: ~/transformers
|
||||
|
|
1
Makefile
1
Makefile
|
@ -41,6 +41,7 @@ repo-consistency:
|
|||
python utils/check_inits.py
|
||||
python utils/check_config_docstrings.py
|
||||
python utils/tests_fetcher.py --sanity_check
|
||||
python utils/update_metadata.py --check-only
|
||||
|
||||
# this target runs checks on all files
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -307,7 +307,7 @@ extras["testing"] = (
|
|||
|
||||
extras["deepspeed-testing"] = extras["deepspeed"] + extras["testing"] + extras["optuna"]
|
||||
|
||||
extras["quality"] = deps_list("black", "isort", "flake8", "GitPython", "hf-doc-builder")
|
||||
extras["quality"] = deps_list("black", "datasets", "isort", "flake8", "GitPython", "hf-doc-builder")
|
||||
|
||||
extras["all"] = (
|
||||
extras["tf"]
|
||||
|
|
|
@ -85,6 +85,12 @@ PIPELINE_TAGS_AND_AUTO_MODELS = [
|
|||
"MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES",
|
||||
"AutoModelForDocumentQuestionAnswering",
|
||||
),
|
||||
(
|
||||
"visual-question-answering",
|
||||
"MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES",
|
||||
"AutoModelForVisualQuestionAnswering",
|
||||
),
|
||||
("image-to-text", "MODEL_FOR_FOR_VISION_2_SEQ_MAPPING_NAMES", "AutoModelForVision2Seq"),
|
||||
]
|
||||
|
||||
|
||||
|
@ -236,10 +242,35 @@ def update_metadata(token, commit_sha):
|
|||
repo.push_to_hub(commit_message)
|
||||
|
||||
|
||||
def check_pipeline_tags():
|
||||
in_table = {tag: cls for tag, _, cls in PIPELINE_TAGS_AND_AUTO_MODELS}
|
||||
pipeline_tasks = transformers_module.pipelines.SUPPORTED_TASKS
|
||||
missing = []
|
||||
for key in pipeline_tasks:
|
||||
if key not in in_table:
|
||||
model = pipeline_tasks[key]["pt"]
|
||||
if isinstance(model, (list, tuple)):
|
||||
model = model[0]
|
||||
model = model.__name__
|
||||
if model not in in_table.values():
|
||||
missing.append(key)
|
||||
|
||||
if len(missing) > 0:
|
||||
msg = ", ".join(missing)
|
||||
raise ValueError(
|
||||
"The following pipeline tags are not present in the `PIPELINE_TAGS_AND_AUTO_MODELS` constant inside "
|
||||
f"`utils/update_metadata.py`: {msg}. Please add them!"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--token", type=str, help="The token to use to push to the transformers-metadata dataset.")
|
||||
parser.add_argument("--commit_sha", type=str, help="The sha of the commit going with this update.")
|
||||
parser.add_argument("--check-only", action="store_true", help="Activate to just check all pipelines are present.")
|
||||
args = parser.parse_args()
|
||||
|
||||
update_metadata(args.token, args.commit_sha)
|
||||
if args.check_only:
|
||||
check_pipeline_tags()
|
||||
else:
|
||||
update_metadata(args.token, args.commit_sha)
|
||||
|
|
Loading…
Reference in New Issue