Update the update metadata job to use upload_folder (#23917)

Sylvain Gugger 2023-05-31 14:10:14 -04:00 committed by GitHub
parent 3ff443a6d9
commit 4aa13224a5
1 changed file with 37 additions and 31 deletions


@@ -21,7 +21,7 @@ import tempfile

 import pandas as pd
 from datasets import Dataset
-from huggingface_hub import Repository
+from huggingface_hub import hf_hub_download, upload_folder

 from transformers.utils import direct_transformers_import
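
Taken together, the change swaps the git-based `Repository` helper for two HTTP helpers from `huggingface_hub`: `hf_hub_download` to read the current `pipeline_tags.json`, and `upload_folder` to push the regenerated files. As a minimal sketch of the read side (not part of this commit; the token argument is omitted here, while the job passes it explicitly):

    from huggingface_hub import hf_hub_download

    # Resolve a single file from the dataset repo into the local cache and
    # return its path; no git clone of the repository is needed.
    resolved_tags_file = hf_hub_download(
        "huggingface/transformers-metadata",
        "pipeline_tags.json",
        repo_type="dataset",
    )
    print(resolved_tags_file)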
@@ -209,43 +209,49 @@ def update_metadata(token, commit_sha):
     """
     Update the metadata for the Transformers repo.
     """
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        repo = Repository(tmp_dir, clone_from="huggingface/transformers-metadata", repo_type="dataset", token=token)
-
-        frameworks_table = get_frameworks_table()
-        frameworks_dataset = Dataset.from_pandas(frameworks_table)
-        frameworks_dataset.to_json(os.path.join(tmp_dir, "frameworks.json"))
+    frameworks_table = get_frameworks_table()
+    frameworks_dataset = Dataset.from_pandas(frameworks_table)

-        tags_dataset = Dataset.from_json(os.path.join(tmp_dir, "pipeline_tags.json"))
-        table = {
-            tags_dataset[i]["model_class"]: (tags_dataset[i]["pipeline_tag"], tags_dataset[i]["auto_class"])
-            for i in range(len(tags_dataset))
-        }
-        table = update_pipeline_and_auto_class_table(table)
+    resolved_tags_file = hf_hub_download(
+        "huggingface/transformers-metadata", "pipeline_tags.json", repo_type="dataset", token=token
+    )
+    tags_dataset = Dataset.from_json(resolved_tags_file)
+    table = {
+        tags_dataset[i]["model_class"]: (tags_dataset[i]["pipeline_tag"], tags_dataset[i]["auto_class"])
+        for i in range(len(tags_dataset))
+    }
+    table = update_pipeline_and_auto_class_table(table)

-        # Sort the model classes to avoid some nondeterministic updates to create false update commits.
-        model_classes = sorted(table.keys())
-        tags_table = pd.DataFrame(
-            {
-                "model_class": model_classes,
-                "pipeline_tag": [table[m][0] for m in model_classes],
-                "auto_class": [table[m][1] for m in model_classes],
-            }
-        )
-        tags_dataset = Dataset.from_pandas(tags_table)
-        tags_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))
+    # Sort the model classes to avoid some nondeterministic updates to create false update commits.
+    model_classes = sorted(table.keys())
+    tags_table = pd.DataFrame(
+        {
+            "model_class": model_classes,
+            "pipeline_tag": [table[m][0] for m in model_classes],
+            "auto_class": [table[m][1] for m in model_classes],
+        }
+    )
+    tags_dataset = Dataset.from_pandas(tags_table)

-        if repo.is_repo_clean():
-            print("Nothing to commit!")
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        frameworks_dataset.to_json(os.path.join(tmp_dir, "frameworks.json"))
+        tags_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))
+
+        if commit_sha is not None:
+            commit_message = (
+                f"Update with commit {commit_sha}\n\nSee: "
+                f"https://github.com/huggingface/transformers/commit/{commit_sha}"
+            )
         else:
-            if commit_sha is not None:
-                commit_message = (
-                    f"Update with commit {commit_sha}\n\nSee: "
-                    f"https://github.com/huggingface/transformers/commit/{commit_sha}"
-                )
-            else:
-                commit_message = "Update"
-            repo.push_to_hub(commit_message)
+            commit_message = "Update"
+
+        upload_folder(
+            repo_id="huggingface/transformers-metadata",
+            folder_path=tmp_dir,
+            repo_type="dataset",
+            token=token,
+            commit_message=commit_message,
+        )


 def check_pipeline_tags():
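
On the write side, `upload_folder` turns the contents of a local folder into a single commit on the Hub, replacing the clone, commit, and `push_to_hub` cycle. A minimal sketch, assuming the temporary directory already holds the regenerated JSON files (the dummy write below is only a stand-in for the `Dataset.to_json` exports above):

    import os
    import tempfile

    from huggingface_hub import upload_folder

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Stand-in for the frameworks.json / pipeline_tags.json exports.
        with open(os.path.join(tmp_dir, "frameworks.json"), "w") as f:
            f.write("{}")

        # Uploads every file in tmp_dir as one commit over the HTTP API;
        # authentication comes from the token argument or the cached login.
        upload_folder(
            repo_id="huggingface/transformers-metadata",
            folder_path=tmp_dir,
            repo_type="dataset",
            commit_message="Update",
        )

One behavioral difference is visible in the diff: the old flow skipped the push when `repo.is_repo_clean()` reported no changes, whereas the new code calls `upload_folder` unconditionally.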