Use repo_type instead of deprecated datasets repo IDs (#19202)

* Use repo_type instead of deprecated datasets repo IDs

* Add missing one in doc
Sylvain Gugger 2022-09-26 09:50:48 -04:00 committed by GitHub
parent 216b2f9e80
commit c20b2c7e18
32 changed files with 74 additions and 69 deletions
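
The change is the same across all 32 files: the Hub has deprecated encoding the repository type as a "datasets/" prefix inside the repo ID, so the type is now passed explicitly through the repo_type argument accepted by both hf_hub_download and hf_hub_url. A minimal before/after sketch of the pattern, assuming the huggingface_hub client and network access to the Hub (the repo and file names are real ones touched by this commit):

import json

from huggingface_hub import cached_download, hf_hub_download, hf_hub_url

# Deprecated: repo type smuggled into the repo ID.
# repo_id = "datasets/huggingface/label-files"

# Current: plain repo ID plus an explicit repo_type.
repo_id = "huggingface/label-files"
filename = "imagenet-1k-id2label.json"

# Variant used by most converter scripts below: download into the local
# cache and get a filesystem path back.
path = hf_hub_download(repo_id, filename, repo_type="dataset")

# Variant used by the docs and a few scripts: resolve the URL first, then
# fetch it through the cache.
path = cached_download(hf_hub_url(repo_id, filename, repo_type="dataset"))

with open(path, "r") as f:
    id2label = {int(k): v for k, v in json.load(f).items()}

Both variants resolve to the same dataset route on the Hub; only the way the repo type reaches the client changes.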

@@ -67,9 +67,9 @@ You'll also want to create a dictionary that maps a label id to a label class wh
 >>> import json
 >>> from huggingface_hub import cached_download, hf_hub_url
->>> repo_id = "datasets/huggingface/label-files"
+>>> repo_id = "huggingface/label-files"
 >>> filename = "ade20k-id2label.json"
->>> id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
+>>> id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename, repo_type="dataset")), "r"))
 >>> id2label = {int(k): v for k, v in id2label.items()}
 >>> label2id = {v: k for k, v in id2label.items()}
 >>> num_labels = len(id2label)

@@ -327,12 +327,12 @@ def main():
 # Prepare label mappings.
 # We'll include these in the model's config to get human readable labels in the Inference API.
 if data_args.dataset_name == "scene_parse_150":
-    repo_id = "datasets/huggingface/label-files"
+    repo_id = "huggingface/label-files"
     filename = "ade20k-id2label.json"
 else:
-    repo_id = f"datasets/{data_args.dataset_name}"
+    repo_id = data_args.dataset_name
     filename = "id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 label2id = {v: str(k) for k, v in id2label.items()}

@@ -387,12 +387,12 @@ def main():
 # Prepare label mappings.
 # We'll include these in the model's config to get human readable labels in the Inference API.
 if args.dataset_name == "scene_parse_150":
-    repo_id = "datasets/huggingface/label-files"
+    repo_id = "huggingface/label-files"
     filename = "ade20k-id2label.json"
 else:
-    repo_id = f"datasets/{args.dataset_name}"
+    repo_id = args.dataset_name
     filename = "id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 label2id = {v: k for k, v in id2label.items()}

@@ -176,7 +176,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
 config = BeitConfig()
 has_lm_head = False
 is_semantic = False
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 # set config parameters based on URL
 if checkpoint_url[-9:-4] == "pt22k":
     # masked image modeling
@@ -188,7 +188,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
 config.use_relative_position_bias = True
 config.num_labels = 21841
 filename = "imagenet-22k-id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 # this dataset contains 21843 labels but the model only has 21841
 # we delete the classes as mentioned in https://github.com/google-research/big_transfer/issues/18
@@ -201,7 +201,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
 config.use_relative_position_bias = True
 config.num_labels = 1000
 filename = "imagenet-1k-id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}
@@ -214,7 +214,7 @@ def convert_beit_checkpoint(checkpoint_url, pytorch_dump_folder_path):
 config.use_relative_position_bias = True
 config.num_labels = 150
 filename = "ade20k-id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -237,9 +237,9 @@ def convert_conditional_detr_checkpoint(model_name, pytorch_dump_folder_path):
     config.num_labels = 250
 else:
     config.num_labels = 91
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 filename = "coco-detection-id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -62,9 +62,9 @@ def get_convnext_config(checkpoint_url):
 filename = "imagenet-22k-id2label.json"
 expected_shape = (1, 21841)
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 config.num_labels = num_labels
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 if "1k" not in checkpoint_url:
     # this dataset contains 21843 labels but the model only has 21841

@@ -282,9 +282,9 @@ def convert_cvt_checkpoint(cvt_model, image_size, cvt_file_name, pytorch_dump_fo
 img_labels_file = "imagenet-1k-id2label.json"
 num_labels = 1000
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 num_labels = num_labels
-id2label = json.load(open(cached_download(hf_hub_url(repo_id, img_labels_file)), "r"))
+id2label = json.load(open(cached_download(hf_hub_url(repo_id, img_labels_file, repo_type="dataset")), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 id2label = id2label

@@ -282,9 +282,9 @@ def main():
 config.use_mean_pooling = True
 config.num_labels = 1000
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 filename = "imagenet-1k-id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -108,9 +108,9 @@ def convert_deformable_detr_checkpoint(
 config.two_stage = two_stage
 # set labels
 config.num_labels = 91
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 filename = "coco-detection-id2label.json"
-id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
+id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename, repo_type="dataset")), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -140,9 +140,9 @@ def convert_deit_checkpoint(deit_name, pytorch_dump_folder_path):
 base_model = False
 # dataset (fine-tuned on ImageNet 2012), patch_size and image_size
 config.num_labels = 1000
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 filename = "imagenet-1k-id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -194,9 +194,9 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
     config.num_labels = 250
 else:
     config.num_labels = 91
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 filename = "coco-detection-id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -149,9 +149,9 @@ def convert_dit_checkpoint(checkpoint_url, pytorch_dump_folder_path, push_to_hub
 # labels
 if "rvlcdip" in checkpoint_url:
     config.num_labels = 16
-    repo_id = "datasets/huggingface/label-files"
+    repo_id = "huggingface/label-files"
     filename = "rvlcdip-id2label.json"
-    id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
     id2label = {int(k): v for k, v in id2label.items()}
     config.id2label = id2label
     config.label2id = {v: k for k, v in id2label.items()}

@@ -48,9 +48,9 @@ def get_dpt_config(checkpoint_url):
 config.use_batch_norm_in_fusion_residual = True
 config.num_labels = 150
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 filename = "ade20k-id2label.json"
-id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
+id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename, repo_type="dataset")), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -85,9 +85,9 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
 num_labels = 1000
 expected_shape = (1, num_labels)
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 num_labels = num_labels
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 id2label = id2label

@@ -62,8 +62,8 @@ def get_mobilevit_config(mobilevit_name):
 config.num_labels = 1000
 filename = "imagenet-1k-id2label.json"
-repo_id = "datasets/huggingface/label-files"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+repo_id = "huggingface/label-files"
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -300,7 +300,7 @@ def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, architec
 # load HuggingFace model
 config = PerceiverConfig()
 subsampling = None
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 if architecture == "MLM":
     config.qk_channels = 8 * 32
     config.v_channels = 1280
@@ -318,7 +318,7 @@ def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, architec
 # set labels
 config.num_labels = 1000
 filename = "imagenet-1k-id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}
@@ -367,7 +367,7 @@ def convert_perceiver_checkpoint(pickle_file, pytorch_dump_folder_path, architec
 model = PerceiverForMultimodalAutoencoding(config)
 # set labels
 filename = "kinetics700-id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -99,14 +99,14 @@ def convert_poolformer_checkpoint(model_name, checkpoint_path, pytorch_dump_fold
 config = PoolFormerConfig()
 # set attributes based on model_name
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 size = model_name[-3:]
 config.num_labels = 1000
 filename = "imagenet-1k-id2label.json"
 expected_shape = (1, 1000)
 # set config attributes
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -163,9 +163,9 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
 filename = "imagenet-1k-id2label.json"
 num_labels = 1000
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 num_labels = num_labels
-id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
+id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename, repo_type="dataset")), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 id2label = id2label

@@ -224,9 +224,9 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
 num_labels = 1000
 expected_shape = (1, num_labels)
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 num_labels = num_labels
-id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename)), "r"))
+id2label = json.load(open(cached_download(hf_hub_url(repo_id, filename, repo_type="dataset")), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 id2label = id2label

@@ -128,9 +128,9 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
 num_labels = 1000
 expected_shape = (1, num_labels)
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 num_labels = num_labels
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 id2label = id2label

@@ -128,7 +128,7 @@ def convert_segformer_checkpoint(model_name, checkpoint_path, pytorch_dump_folde
 encoder_only = False
 # set attributes based on model_name
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 if "segformer" in model_name:
     size = model_name[len("segformer.") : len("segformer.") + 2]
     if "ade" in model_name:
@@ -151,7 +151,7 @@ def convert_segformer_checkpoint(model_name, checkpoint_path, pytorch_dump_folde
 raise ValueError(f"Model {model_name} not supported")
 # set config attributes
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -39,9 +39,9 @@ def get_swin_config(swin_name):
     num_classes = 21841
 else:
     num_classes = 1000
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 filename = "imagenet-1k-id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -63,18 +63,18 @@ def get_swinv2_config(swinv2_name):
 if ("22k" in swinv2_name) and ("to" not in swinv2_name):
     num_classes = 21841
-    repo_id = "datasets/huggingface/label-files"
+    repo_id = "huggingface/label-files"
     filename = "imagenet-22k-id2label.json"
-    id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
     id2label = {int(k): v for k, v in id2label.items()}
     config.id2label = id2label
     config.label2id = {v: k for k, v in id2label.items()}
 else:
     num_classes = 1000
-    repo_id = "datasets/huggingface/label-files"
+    repo_id = "huggingface/label-files"
     filename = "imagenet-1k-id2label.json"
-    id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
     id2label = {int(k): v for k, v in id2label.items()}
     config.id2label = id2label
     config.label2id = {v: k for k, v in id2label.items()}

@@ -168,9 +168,9 @@ def convert_weights_and_push(save_directory: Path, model_name: str = None, push_
 filename = "imagenet-1k-id2label.json"
 num_labels = 1000
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 num_labels = num_labels
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 id2label = id2label

@@ -47,7 +47,7 @@ def get_videomae_config(model_name):
 config.use_mean_pooling = False
 if "finetuned" in model_name:
-    repo_id = "datasets/huggingface/label-files"
+    repo_id = "huggingface/label-files"
     if "kinetics" in model_name:
         config.num_labels = 400
         filename = "kinetics400-id2label.json"
@@ -56,7 +56,7 @@ def get_videomae_config(model_name):
         filename = "something-something-v2-id2label.json"
     else:
         raise ValueError("Model name should either contain 'kinetics' or 'ssv2' in case it's fine-tuned.")
-    id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
     id2label = {int(k): v for k, v in id2label.items()}
     config.id2label = id2label
     config.label2id = {v: k for k, v in id2label.items()}
@@ -145,7 +145,9 @@ def convert_state_dict(orig_state_dict, config):
 # We will verify our results on a video of eating spaghetti
 # Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]
 def prepare_video():
-    file = hf_hub_download(repo_id="datasets/hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy")
+    file = hf_hub_download(
+        repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy", repo_type="dataset"
+    )
     video = np.load(file)
     return list(video)

@@ -180,9 +180,9 @@ def convert_vilt_checkpoint(checkpoint_url, pytorch_dump_folder_path):
 if "vqa" in checkpoint_url:
     vqa_model = True
     config.num_labels = 3129
-    repo_id = "datasets/huggingface/label-files"
+    repo_id = "huggingface/label-files"
     filename = "vqa2-id2label.json"
-    id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
     id2label = {int(k): v for k, v in id2label.items()}
     config.id2label = id2label
     config.label2id = {v: k for k, v in id2label.items()}

@@ -142,9 +142,9 @@ def convert_vit_checkpoint(model_name, pytorch_dump_folder_path, base_model=True
 # set labels if required
 if not base_model:
     config.num_labels = 1000
-    repo_id = "datasets/huggingface/label-files"
+    repo_id = "huggingface/label-files"
     filename = "imagenet-1k-id2label.json"
-    id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
     id2label = {int(k): v for k, v in id2label.items()}
     config.id2label = id2label
     config.label2id = {v: k for k, v in id2label.items()}

@@ -147,9 +147,9 @@ def convert_vit_checkpoint(vit_name, pytorch_dump_folder_path):
     config.image_size = int(vit_name[-9:-6])
 else:
     config.num_labels = 1000
-    repo_id = "datasets/huggingface/label-files"
+    repo_id = "huggingface/label-files"
     filename = "imagenet-1k-id2label.json"
-    id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+    id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
     id2label = {int(k): v for k, v in id2label.items()}
     config.id2label = id2label
     config.label2id = {v: k for k, v in id2label.items()}

@@ -207,8 +207,9 @@ def prepare_video(num_frames):
 elif num_frames == 32:
     filename = "eating_spaghetti_32_frames.npy"
 file = hf_hub_download(
-    repo_id="datasets/hf-internal-testing/spaghetti-video",
+    repo_id="hf-internal-testing/spaghetti-video",
     filename=filename,
+    repo_type="dataset",
 )
 video = np.load(file)
 return list(video)

@@ -57,9 +57,9 @@ def get_yolos_config(yolos_name):
 config.image_size = [800, 1344]
 config.num_labels = 91
-repo_id = "datasets/huggingface/label-files"
+repo_id = "huggingface/label-files"
 filename = "coco-detection-id2label.json"
-id2label = json.load(open(hf_hub_download(repo_id, filename), "r"))
+id2label = json.load(open(hf_hub_download(repo_id, filename, repo_type="dataset"), "r"))
 id2label = {int(k): v for k, v in id2label.items()}
 config.id2label = id2label
 config.label2id = {v: k for k, v in id2label.items()}

@@ -342,7 +342,9 @@ class VideoMAEModelTest(ModelTesterMixin, unittest.TestCase):
 # We will verify our results on a video of eating spaghetti
 # Frame indices used: [164 168 172 176 181 185 189 193 198 202 206 210 215 219 223 227]
 def prepare_video():
-    file = hf_hub_download(repo_id="datasets/hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy")
+    file = hf_hub_download(
+        repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti.npy", repo_type="dataset"
+    )
     video = np.load(file)
     return list(video)

@@ -633,7 +633,7 @@ class XCLIPModelTest(ModelTesterMixin, unittest.TestCase):
 # We will verify our results on a spaghetti video
 def prepare_video():
     file = hf_hub_download(
-        repo_id="datasets/hf-internal-testing/spaghetti-video", filename="eating_spaghetti_8_frames.npy"
+        repo_id="hf-internal-testing/spaghetti-video", filename="eating_spaghetti_8_frames.npy", repo_type="dataset"
     )
     video = np.load(file)
     return list(video)