Update Transformers to huggingface_hub >= 0.1.0 (#14251)
* Update Transformers to huggingface_hub >= 0.1.0 * Forgot to save... * Style * Fix test
This commit is contained in:
parent
519a677e87
commit
558f8543ba
|
@ -103,8 +103,8 @@ Here is the code to see all available pretrained models on the hub:
|
|||
|
||||
.. code-block:: python
|
||||
|
||||
from huggingface_hub.hf_api import HfApi
|
||||
model_list = HfApi().list_models()
|
||||
from huggingface_hub import list_models
|
||||
model_list = list_models()
|
||||
org = "Helsinki-NLP"
|
||||
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
|
||||
suffix = [x.split('/')[1] for x in model_ids]
|
||||
|
|
|
@ -14,7 +14,7 @@ import lightning_base
|
|||
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||
from distillation import distill_main
|
||||
from finetune import SummarizationModule, main
|
||||
from huggingface_hub.hf_api import HfApi
|
||||
from huggingface_hub import list_models
|
||||
from parameterized import parameterized
|
||||
from run_eval import generate_summaries_or_translations
|
||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
||||
|
@ -130,7 +130,7 @@ class TestSummarizationDistiller(TestCasePlus):
|
|||
def test_hub_configs(self):
|
||||
"""I put require_torch_gpu cause I only want this to run with self-scheduled."""
|
||||
|
||||
model_list = HfApi().list_models()
|
||||
model_list = list_models()
|
||||
org = "sshleifer"
|
||||
model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)]
|
||||
allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"]
|
||||
|
|
2
setup.py
2
setup.py
|
@ -100,7 +100,7 @@ _deps = [
|
|||
"flax>=0.3.4",
|
||||
"fugashi>=1.0",
|
||||
"GitPython<3.1.19",
|
||||
"huggingface-hub>=0.0.17",
|
||||
"huggingface-hub>=0.1.0,<1.0",
|
||||
"importlib_metadata",
|
||||
"ipadic>=1.0.0,<2.0",
|
||||
"isort>=5.5.4",
|
||||
|
|
|
@ -12,14 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from getpass import getpass
|
||||
from typing import List, Union
|
||||
|
||||
from huggingface_hub.hf_api import HfApi, HfFolder
|
||||
from huggingface_hub.hf_api import HfFolder, create_repo, list_repos_objs, login, logout, whoami
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from . import BaseTransformersCLICommand
|
||||
|
@ -142,7 +140,6 @@ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str:
|
|||
class BaseUserCommand:
|
||||
def __init__(self, args):
|
||||
self.args = args
|
||||
self._api = HfApi()
|
||||
|
||||
|
||||
class LoginCommand(BaseUserCommand):
|
||||
|
@ -166,7 +163,7 @@ class LoginCommand(BaseUserCommand):
|
|||
username = input("Username: ")
|
||||
password = getpass()
|
||||
try:
|
||||
token = self._api.login(username, password)
|
||||
token = login(username, password)
|
||||
except HTTPError as e:
|
||||
# probably invalid credentials, display error message.
|
||||
print(e)
|
||||
|
@ -191,7 +188,7 @@ class WhoamiCommand(BaseUserCommand):
|
|||
print("Not logged in")
|
||||
exit()
|
||||
try:
|
||||
user, orgs = self._api.whoami(token)
|
||||
user, orgs = whoami(token)
|
||||
print(user)
|
||||
if orgs:
|
||||
print(ANSI.bold("orgs: "), ",".join(orgs))
|
||||
|
@ -214,7 +211,7 @@ class LogoutCommand(BaseUserCommand):
|
|||
print("Not logged in")
|
||||
exit()
|
||||
HfFolder.delete_token()
|
||||
self._api.logout(token)
|
||||
logout(token)
|
||||
print("Successfully logged out.")
|
||||
|
||||
|
||||
|
@ -222,46 +219,24 @@ class ListObjsCommand(BaseUserCommand):
|
|||
def run(self):
|
||||
print(
|
||||
ANSI.red(
|
||||
"WARNING! Managing repositories through transformers-cli is deprecated. "
|
||||
"Please use `huggingface-cli` instead."
|
||||
"Command removed: it used to be the way to delete an object on S3."
|
||||
" We now use a git-based system for storing models and other artifacts."
|
||||
" Use list-repo-objs instead"
|
||||
)
|
||||
)
|
||||
token = HfFolder.get_token()
|
||||
if token is None:
|
||||
print("Not logged in")
|
||||
exit(1)
|
||||
try:
|
||||
objs = self._api.list_objs(token, organization=self.args.organization)
|
||||
except HTTPError as e:
|
||||
print(e)
|
||||
print(ANSI.red(e.response.text))
|
||||
exit(1)
|
||||
if len(objs) == 0:
|
||||
print("No shared file yet")
|
||||
exit()
|
||||
rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs]
|
||||
print(tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"]))
|
||||
exit(1)
|
||||
|
||||
|
||||
class DeleteObjCommand(BaseUserCommand):
|
||||
def run(self):
|
||||
print(
|
||||
ANSI.red(
|
||||
"WARNING! Managing repositories through transformers-cli is deprecated. "
|
||||
"Please use `huggingface-cli` instead."
|
||||
"Command removed: it used to be the way to delete an object on S3."
|
||||
" We now use a git-based system for storing models and other artifacts."
|
||||
" Use delete-repo instead"
|
||||
)
|
||||
)
|
||||
token = HfFolder.get_token()
|
||||
if token is None:
|
||||
print("Not logged in")
|
||||
exit(1)
|
||||
try:
|
||||
self._api.delete_obj(token, filename=self.args.filename, organization=self.args.organization)
|
||||
except HTTPError as e:
|
||||
print(e)
|
||||
print(ANSI.red(e.response.text))
|
||||
exit(1)
|
||||
print("Done")
|
||||
exit(1)
|
||||
|
||||
|
||||
class ListReposObjsCommand(BaseUserCommand):
|
||||
|
@ -277,7 +252,7 @@ class ListReposObjsCommand(BaseUserCommand):
|
|||
print("Not logged in")
|
||||
exit(1)
|
||||
try:
|
||||
objs = self._api.list_repos_objs(token, organization=self.args.organization)
|
||||
objs = list_repos_objs(token, organization=self.args.organization)
|
||||
except HTTPError as e:
|
||||
print(e)
|
||||
print(ANSI.red(e.response.text))
|
||||
|
@ -320,7 +295,7 @@ class RepoCreateCommand(BaseUserCommand):
|
|||
)
|
||||
print("")
|
||||
|
||||
user, _ = self._api.whoami(token)
|
||||
user, _ = whoami(token)
|
||||
namespace = self.args.organization if self.args.organization is not None else user
|
||||
full_name = f"{namespace}/{self.args.name}"
|
||||
print(f"You are about to create {ANSI.bold(full_name)}")
|
||||
|
@ -331,7 +306,7 @@ class RepoCreateCommand(BaseUserCommand):
|
|||
print("Abort")
|
||||
exit()
|
||||
try:
|
||||
url = self._api.create_repo(token, name=self.args.name, organization=self.args.organization)
|
||||
url = create_repo(token, name=self.args.name, organization=self.args.organization)
|
||||
except HTTPError as e:
|
||||
print(e)
|
||||
print(ANSI.red(e.response.text))
|
||||
|
@ -356,73 +331,12 @@ class DeprecatedUploadCommand(BaseUserCommand):
|
|||
|
||||
|
||||
class UploadCommand(BaseUserCommand):
|
||||
def walk_dir(self, rel_path):
|
||||
"""
|
||||
Recursively list all files in a folder.
|
||||
"""
|
||||
entries: List[os.DirEntry] = list(os.scandir(rel_path))
|
||||
files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # (filepath, filename)
|
||||
for f in entries:
|
||||
if f.is_dir():
|
||||
files += self.walk_dir(f.path)
|
||||
return files
|
||||
|
||||
def run(self):
|
||||
print(
|
||||
ANSI.red(
|
||||
"WARNING! Managing repositories through transformers-cli is deprecated. "
|
||||
"Please use `huggingface-cli` instead."
|
||||
"Deprecated: used to be the way to upload a model to S3."
|
||||
" We now use a git-based system for storing models and other artifacts."
|
||||
" Use the `repo create` command instead."
|
||||
)
|
||||
)
|
||||
token = HfFolder.get_token()
|
||||
if token is None:
|
||||
print("Not logged in")
|
||||
exit(1)
|
||||
local_path = os.path.abspath(self.args.path)
|
||||
if os.path.isdir(local_path):
|
||||
if self.args.filename is not None:
|
||||
raise ValueError("Cannot specify a filename override when uploading a folder.")
|
||||
rel_path = os.path.basename(local_path)
|
||||
files = self.walk_dir(rel_path)
|
||||
elif os.path.isfile(local_path):
|
||||
filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path)
|
||||
files = [(local_path, filename)]
|
||||
else:
|
||||
raise ValueError(f"Not a valid file or directory: {local_path}")
|
||||
|
||||
if sys.platform == "win32":
|
||||
files = [(filepath, filename.replace(os.sep, "/")) for filepath, filename in files]
|
||||
|
||||
if len(files) > UPLOAD_MAX_FILES:
|
||||
print(
|
||||
f"About to upload {ANSI.bold(len(files))} files to S3. This is probably wrong. Please filter files "
|
||||
"before uploading."
|
||||
)
|
||||
exit(1)
|
||||
|
||||
user, _ = self._api.whoami(token)
|
||||
namespace = self.args.organization if self.args.organization is not None else user
|
||||
|
||||
for filepath, filename in files:
|
||||
print(
|
||||
f"About to upload file {ANSI.bold(filepath)} to S3 under filename {ANSI.bold(filename)} and namespace "
|
||||
f"{ANSI.bold(namespace)}"
|
||||
)
|
||||
|
||||
if not self.args.yes:
|
||||
choice = input("Proceed? [Y/n] ").lower()
|
||||
if not (choice == "" or choice == "y" or choice == "yes"):
|
||||
print("Abort")
|
||||
exit()
|
||||
print(ANSI.bold("Uploading... This might take a while if files are large"))
|
||||
for filepath, filename in files:
|
||||
try:
|
||||
access_url = self._api.presign_and_upload(
|
||||
token=token, filename=filename, filepath=filepath, organization=self.args.organization
|
||||
)
|
||||
except HTTPError as e:
|
||||
print(e)
|
||||
print(ANSI.red(e.response.text))
|
||||
exit(1)
|
||||
print("Your file now lives at:")
|
||||
print(access_url)
|
||||
exit(1)
|
||||
|
|
|
@ -18,7 +18,7 @@ deps = {
|
|||
"flax": "flax>=0.3.4",
|
||||
"fugashi": "fugashi>=1.0",
|
||||
"GitPython": "GitPython<3.1.19",
|
||||
"huggingface-hub": "huggingface-hub>=0.0.17",
|
||||
"huggingface-hub": "huggingface-hub>=0.1.0,<1.0",
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"ipadic": "ipadic>=1.0.0,<2.0",
|
||||
"isort": "isort>=5.5.4",
|
||||
|
|
|
@ -48,7 +48,7 @@ from tqdm.auto import tqdm
|
|||
|
||||
import requests
|
||||
from filelock import FileLock
|
||||
from huggingface_hub import HfApi, HfFolder, Repository
|
||||
from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami
|
||||
from transformers.utils.versions import importlib_metadata
|
||||
|
||||
from . import __version__
|
||||
|
@ -1808,17 +1808,14 @@ def get_list_of_files(
|
|||
if is_offline_mode() or local_files_only:
|
||||
return []
|
||||
|
||||
# Otherwise we grab the token and use the model_info method.
|
||||
# Otherwise we grab the token and use the list_repo_files method.
|
||||
if isinstance(use_auth_token, str):
|
||||
token = use_auth_token
|
||||
elif use_auth_token is True:
|
||||
token = HfFolder.get_token()
|
||||
else:
|
||||
token = None
|
||||
model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info(
|
||||
path_or_repo, revision=revision, token=token
|
||||
)
|
||||
return [f.rfilename for f in model_info.siblings]
|
||||
return list_repo_files(path_or_repo, revision=revision, token=token)
|
||||
|
||||
|
||||
class cached_property(property):
|
||||
|
@ -2308,7 +2305,7 @@ class PushToHubMixin:
|
|||
token = None
|
||||
|
||||
# Special provision for the test endpoint (CI)
|
||||
return HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).create_repo(
|
||||
return create_repo(
|
||||
token,
|
||||
repo_name,
|
||||
organization=organization,
|
||||
|
@ -2366,7 +2363,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
|
|||
if token is None:
|
||||
token = HfFolder.get_token()
|
||||
if organization is None:
|
||||
username = HfApi().whoami(token)["name"]
|
||||
username = whoami(token)["name"]
|
||||
return f"{username}/{model_id}"
|
||||
else:
|
||||
return f"{organization}/{model_id}"
|
||||
|
|
|
@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Union
|
|||
|
||||
import requests
|
||||
import yaml
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import model_info
|
||||
|
||||
from . import __version__
|
||||
from .file_utils import (
|
||||
|
@ -387,8 +387,8 @@ class TrainingSummary:
|
|||
and len(self.finetuned_from) > 0
|
||||
):
|
||||
try:
|
||||
model_info = HfApi().model_info(self.finetuned_from)
|
||||
for tag in model_info.tags:
|
||||
info = model_info(self.finetuned_from)
|
||||
for tag in info.tags:
|
||||
if tag.startswith("license:"):
|
||||
self.license = tag[8:]
|
||||
except requests.exceptions.HTTPError:
|
||||
|
|
|
@ -27,7 +27,7 @@ import torch
|
|||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
|
||||
from huggingface_hub.hf_api import HfApi
|
||||
from huggingface_hub.hf_api import list_models
|
||||
from transformers import MarianConfig, MarianMTModel, MarianTokenizer
|
||||
|
||||
|
||||
|
@ -64,8 +64,7 @@ def load_layers_(layer_lst: nn.ModuleList, opus_state: dict, converter, is_decod
|
|||
def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]:
|
||||
"""Find models that can accept src_lang as input and return tgt_lang as output."""
|
||||
prefix = "Helsinki-NLP/opus-mt-"
|
||||
api = HfApi()
|
||||
model_list = api.list_models()
|
||||
model_list = list_models()
|
||||
model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")]
|
||||
src_and_targ = [
|
||||
remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m
|
||||
|
|
|
@ -19,11 +19,11 @@ import os
|
|||
import tempfile
|
||||
import unittest
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import BertConfig, GPT2Config, is_torch_available
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
|
||||
from transformers.testing_utils import PASS, USER, is_staging_test
|
||||
|
||||
|
||||
config_common_kwargs = {
|
||||
|
@ -194,18 +194,17 @@ class ConfigTester(object):
|
|||
class ConfigPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||
cls._token = cls._api.login(username=USER, password=PASS)
|
||||
cls._token = login(username=USER, password=PASS)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-config")
|
||||
delete_repo(token=cls._token, name="test-config")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-config-org", organization="valid_org")
|
||||
delete_repo(token=cls._token, name="test-config-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
|
|
|
@ -28,13 +28,12 @@ from typing import Dict, List, Tuple
|
|||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from huggingface_hub import HfApi, Repository
|
||||
from huggingface_hub import Repository, delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
|
||||
from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
PASS,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
|
@ -2122,23 +2121,22 @@ class FakeModel(PreTrainedModel):
|
|||
class ModelPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||
cls._token = cls._api.login(username=USER, password=PASS)
|
||||
cls._token = login(username=USER, password=PASS)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model")
|
||||
delete_repo(token=cls._token, name="test-model")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
|
||||
delete_repo(token=cls._token, name="test-model-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-dynamic-model")
|
||||
delete_repo(token=cls._token, name="test-dynamic-model")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
|
|
|
@ -22,19 +22,11 @@ from typing import List, Tuple
|
|||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import BertConfig, is_flax_available, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
PASS,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
is_pt_flax_cross_test,
|
||||
is_staging_test,
|
||||
require_flax,
|
||||
)
|
||||
from transformers.testing_utils import PASS, USER, CaptureLogger, is_pt_flax_cross_test, is_staging_test, require_flax
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
|
@ -627,18 +619,17 @@ class FlaxModelTesterMixin:
|
|||
class FlaxModelPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||
cls._token = cls._api.login(username=USER, password=PASS)
|
||||
cls._token = login(username=USER, password=PASS)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model-flax")
|
||||
delete_repo(token=cls._token, name="test-model-flax")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org")
|
||||
delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
import tempfile
|
||||
import unittest
|
||||
|
||||
from huggingface_hub.hf_api import HfApi
|
||||
from huggingface_hub.hf_api import list_models
|
||||
from transformers import MarianConfig, is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
|
@ -296,7 +296,7 @@ class ModelManagementTests(unittest.TestCase):
|
|||
@slow
|
||||
@require_torch
|
||||
def test_model_names(self):
|
||||
model_list = HfApi().list_models()
|
||||
model_list = list_models()
|
||||
model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)]
|
||||
bad_model_ids = [mid for mid in model_ids if "+" in model_ids]
|
||||
self.assertListEqual([], bad_model_ids)
|
||||
|
|
|
@ -24,12 +24,11 @@ import unittest
|
|||
from importlib import import_module
|
||||
from typing import List, Tuple
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import is_tf_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
PASS,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
|
@ -1530,18 +1529,17 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||
class TFModelPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||
cls._token = cls._api.login(username=USER, password=PASS)
|
||||
cls._token = login(username=USER, password=PASS)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model-tf")
|
||||
delete_repo(token=cls._token, name="test-model-tf")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-model-tf-org", organization="valid_org")
|
||||
delete_repo(token=cls._token, name="test-model-tf-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ from collections import OrderedDict
|
|||
from itertools import takewhile
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import (
|
||||
AlbertTokenizer,
|
||||
|
@ -44,7 +44,6 @@ from transformers import (
|
|||
is_torch_available,
|
||||
)
|
||||
from transformers.testing_utils import (
|
||||
ENDPOINT_STAGING,
|
||||
PASS,
|
||||
USER,
|
||||
get_tests_dir,
|
||||
|
@ -3520,18 +3519,17 @@ class TokenizerPushToHubTester(unittest.TestCase):
|
|||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||
cls._token = cls._api.login(username=USER, password=PASS)
|
||||
cls._token = login(username=USER, password=PASS)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-tokenizer")
|
||||
delete_repo(token=cls._token, name="test-tokenizer")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-tokenizer-org", organization="valid_org")
|
||||
delete_repo(token=cls._token, name="test-tokenizer-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ from pathlib import Path
|
|||
|
||||
import numpy as np
|
||||
|
||||
from huggingface_hub import HfApi, Repository
|
||||
from huggingface_hub import Repository, delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
|
@ -1307,19 +1307,18 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||
class TrainerIntegrationWithHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._api = HfApi(endpoint=ENDPOINT_STAGING)
|
||||
cls._token = cls._api.login(username=USER, password=PASS)
|
||||
cls._token = login(username=USER, password=PASS)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]:
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name=model)
|
||||
delete_repo(token=cls._token, name=model)
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cls._api.delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org")
|
||||
delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
|
@ -1396,6 +1395,10 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||
print(commits, len(commits))
|
||||
|
||||
def test_push_to_hub_with_saves_each_n_steps(self):
|
||||
num_gpus = max(1, get_gpu_count())
|
||||
if num_gpus > 2:
|
||||
return
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
trainer = get_regression_trainer(
|
||||
output_dir=os.path.join(tmp_dir, "test-trainer-step"),
|
||||
|
@ -1409,7 +1412,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
_ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token)
|
||||
commits = self.get_commit_history(tmp_dir)
|
||||
expected_commits = [f"Training in progress, step {i}" for i in range(20, 0, -5)]
|
||||
total_steps = 20 // num_gpus
|
||||
expected_commits = [f"Training in progress, step {i}" for i in range(total_steps, 0, -5)]
|
||||
expected_commits.append("initial commit")
|
||||
self.assertListEqual(commits, expected_commits)
|
||||
print(commits, len(commits))
|
||||
|
|
Loading…
Reference in New Issue