From 558f8543ba3860c736a7a9a4176ac20f23f9d5a0 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 2 Nov 2021 18:58:42 -0400 Subject: [PATCH] Update Transformers to huggingface_hub >= 0.1.0 (#14251) * Update Transformers to huggingface_hub >= 0.1.0 * Forgot to save... * Style * Fix test --- docs/source/model_doc/marian.rst | 4 +- .../_test_seq2seq_examples.py | 4 +- setup.py | 2 +- src/transformers/commands/user.py | 124 +++--------------- src/transformers/dependency_versions_table.py | 2 +- src/transformers/file_utils.py | 13 +- src/transformers/modelcard.py | 6 +- .../marian/convert_marian_to_pytorch.py | 5 +- tests/test_configuration_common.py | 11 +- tests/test_modeling_common.py | 12 +- tests/test_modeling_flax_common.py | 19 +-- tests/test_modeling_marian.py | 4 +- tests/test_modeling_tf_common.py | 10 +- tests/test_tokenization_common.py | 10 +- tests/test_trainer.py | 16 ++- 15 files changed, 70 insertions(+), 172 deletions(-) diff --git a/docs/source/model_doc/marian.rst b/docs/source/model_doc/marian.rst index 2f52c696d1..b461f0a9ba 100644 --- a/docs/source/model_doc/marian.rst +++ b/docs/source/model_doc/marian.rst @@ -103,8 +103,8 @@ Here is the code to see all available pretrained models on the hub: .. code-block:: python - from huggingface_hub.hf_api import HfApi - model_list = HfApi().list_models() + from huggingface_hub import list_models + model_list = list_models() org = "Helsinki-NLP" model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)] suffix = [x.split('/')[1] for x in model_ids] diff --git a/examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py b/examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py index bb2c9222c0..d97c9d43b3 100644 --- a/examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py +++ b/examples/research_projects/seq2seq-distillation/_test_seq2seq_examples.py @@ -14,7 +14,7 @@ import lightning_base from convert_pl_checkpoint_to_hf import convert_pl_to_hf from distillation import distill_main from finetune import SummarizationModule, main -from huggingface_hub.hf_api import HfApi +from huggingface_hub import list_models from parameterized import parameterized from run_eval import generate_summaries_or_translations from transformers import AutoConfig, AutoModelForSeq2SeqLM @@ -130,7 +130,7 @@ class TestSummarizationDistiller(TestCasePlus): def test_hub_configs(self): """I put require_torch_gpu cause I only want this to run with self-scheduled.""" - model_list = HfApi().list_models() + model_list = list_models() org = "sshleifer" model_ids = [x.modelId for x in model_list if x.modelId.startswith(org)] allowed_to_be_broken = ["sshleifer/blenderbot-3B", "sshleifer/blenderbot-90M"] diff --git a/setup.py b/setup.py index 2b9a19ba57..a3100d6415 100644 --- a/setup.py +++ b/setup.py @@ -100,7 +100,7 @@ _deps = [ "flax>=0.3.4", "fugashi>=1.0", "GitPython<3.1.19", - "huggingface-hub>=0.0.17", + "huggingface-hub>=0.1.0,<1.0", "importlib_metadata", "ipadic>=1.0.0,<2.0", "isort>=5.5.4", diff --git a/src/transformers/commands/user.py b/src/transformers/commands/user.py index 6f690c3edc..a3919b4cb1 100644 --- a/src/transformers/commands/user.py +++ b/src/transformers/commands/user.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import subprocess -import sys from argparse import ArgumentParser from getpass import getpass from typing import List, Union -from huggingface_hub.hf_api import HfApi, HfFolder +from huggingface_hub.hf_api import HfFolder, create_repo, list_repos_objs, login, logout, whoami from requests.exceptions import HTTPError from . import BaseTransformersCLICommand @@ -142,7 +140,6 @@ def tabulate(rows: List[List[Union[str, int]]], headers: List[str]) -> str: class BaseUserCommand: def __init__(self, args): self.args = args - self._api = HfApi() class LoginCommand(BaseUserCommand): @@ -166,7 +163,7 @@ class LoginCommand(BaseUserCommand): username = input("Username: ") password = getpass() try: - token = self._api.login(username, password) + token = login(username, password) except HTTPError as e: # probably invalid credentials, display error message. print(e) @@ -191,7 +188,7 @@ class WhoamiCommand(BaseUserCommand): print("Not logged in") exit() try: - user, orgs = self._api.whoami(token) + user, orgs = whoami(token) print(user) if orgs: print(ANSI.bold("orgs: "), ",".join(orgs)) @@ -214,7 +211,7 @@ class LogoutCommand(BaseUserCommand): print("Not logged in") exit() HfFolder.delete_token() - self._api.logout(token) + logout(token) print("Successfully logged out.") @@ -222,46 +219,24 @@ class ListObjsCommand(BaseUserCommand): def run(self): print( ANSI.red( - "WARNING! Managing repositories through transformers-cli is deprecated. " - "Please use `huggingface-cli` instead." + "Command removed: it used to be the way to delete an object on S3." + " We now use a git-based system for storing models and other artifacts." + " Use list-repo-objs instead" ) ) - token = HfFolder.get_token() - if token is None: - print("Not logged in") - exit(1) - try: - objs = self._api.list_objs(token, organization=self.args.organization) - except HTTPError as e: - print(e) - print(ANSI.red(e.response.text)) - exit(1) - if len(objs) == 0: - print("No shared file yet") - exit() - rows = [[obj.filename, obj.LastModified, obj.ETag, obj.Size] for obj in objs] - print(tabulate(rows, headers=["Filename", "LastModified", "ETag", "Size"])) + exit(1) class DeleteObjCommand(BaseUserCommand): def run(self): print( ANSI.red( - "WARNING! Managing repositories through transformers-cli is deprecated. " - "Please use `huggingface-cli` instead." + "Command removed: it used to be the way to delete an object on S3." + " We now use a git-based system for storing models and other artifacts." + " Use delete-repo instead" ) ) - token = HfFolder.get_token() - if token is None: - print("Not logged in") - exit(1) - try: - self._api.delete_obj(token, filename=self.args.filename, organization=self.args.organization) - except HTTPError as e: - print(e) - print(ANSI.red(e.response.text)) - exit(1) - print("Done") + exit(1) class ListReposObjsCommand(BaseUserCommand): @@ -277,7 +252,7 @@ class ListReposObjsCommand(BaseUserCommand): print("Not logged in") exit(1) try: - objs = self._api.list_repos_objs(token, organization=self.args.organization) + objs = list_repos_objs(token, organization=self.args.organization) except HTTPError as e: print(e) print(ANSI.red(e.response.text)) @@ -320,7 +295,7 @@ class RepoCreateCommand(BaseUserCommand): ) print("") - user, _ = self._api.whoami(token) + user, _ = whoami(token) namespace = self.args.organization if self.args.organization is not None else user full_name = f"{namespace}/{self.args.name}" print(f"You are about to create {ANSI.bold(full_name)}") @@ -331,7 +306,7 @@ class RepoCreateCommand(BaseUserCommand): print("Abort") exit() try: - url = self._api.create_repo(token, name=self.args.name, organization=self.args.organization) + url = create_repo(token, name=self.args.name, organization=self.args.organization) except HTTPError as e: print(e) print(ANSI.red(e.response.text)) @@ -356,73 +331,12 @@ class DeprecatedUploadCommand(BaseUserCommand): class UploadCommand(BaseUserCommand): - def walk_dir(self, rel_path): - """ - Recursively list all files in a folder. - """ - entries: List[os.DirEntry] = list(os.scandir(rel_path)) - files = [(os.path.join(os.getcwd(), f.path), f.path) for f in entries if f.is_file()] # (filepath, filename) - for f in entries: - if f.is_dir(): - files += self.walk_dir(f.path) - return files - def run(self): print( ANSI.red( - "WARNING! Managing repositories through transformers-cli is deprecated. " - "Please use `huggingface-cli` instead." + "Deprecated: used to be the way to upload a model to S3." + " We now use a git-based system for storing models and other artifacts." + " Use the `repo create` command instead." ) ) - token = HfFolder.get_token() - if token is None: - print("Not logged in") - exit(1) - local_path = os.path.abspath(self.args.path) - if os.path.isdir(local_path): - if self.args.filename is not None: - raise ValueError("Cannot specify a filename override when uploading a folder.") - rel_path = os.path.basename(local_path) - files = self.walk_dir(rel_path) - elif os.path.isfile(local_path): - filename = self.args.filename if self.args.filename is not None else os.path.basename(local_path) - files = [(local_path, filename)] - else: - raise ValueError(f"Not a valid file or directory: {local_path}") - - if sys.platform == "win32": - files = [(filepath, filename.replace(os.sep, "/")) for filepath, filename in files] - - if len(files) > UPLOAD_MAX_FILES: - print( - f"About to upload {ANSI.bold(len(files))} files to S3. This is probably wrong. Please filter files " - "before uploading." - ) - exit(1) - - user, _ = self._api.whoami(token) - namespace = self.args.organization if self.args.organization is not None else user - - for filepath, filename in files: - print( - f"About to upload file {ANSI.bold(filepath)} to S3 under filename {ANSI.bold(filename)} and namespace " - f"{ANSI.bold(namespace)}" - ) - - if not self.args.yes: - choice = input("Proceed? [Y/n] ").lower() - if not (choice == "" or choice == "y" or choice == "yes"): - print("Abort") - exit() - print(ANSI.bold("Uploading... This might take a while if files are large")) - for filepath, filename in files: - try: - access_url = self._api.presign_and_upload( - token=token, filename=filename, filepath=filepath, organization=self.args.organization - ) - except HTTPError as e: - print(e) - print(ANSI.red(e.response.text)) - exit(1) - print("Your file now lives at:") - print(access_url) + exit(1) diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 9168ad5c05..786e5a5691 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -18,7 +18,7 @@ deps = { "flax": "flax>=0.3.4", "fugashi": "fugashi>=1.0", "GitPython": "GitPython<3.1.19", - "huggingface-hub": "huggingface-hub>=0.0.17", + "huggingface-hub": "huggingface-hub>=0.1.0,<1.0", "importlib_metadata": "importlib_metadata", "ipadic": "ipadic>=1.0.0,<2.0", "isort": "isort>=5.5.4", diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 9ed77a0327..3b24b4c737 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -48,7 +48,7 @@ from tqdm.auto import tqdm import requests from filelock import FileLock -from huggingface_hub import HfApi, HfFolder, Repository +from huggingface_hub import HfFolder, Repository, create_repo, list_repo_files, whoami from transformers.utils.versions import importlib_metadata from . import __version__ @@ -1808,17 +1808,14 @@ def get_list_of_files( if is_offline_mode() or local_files_only: return [] - # Otherwise we grab the token and use the model_info method. + # Otherwise we grab the token and use the list_repo_files method. if isinstance(use_auth_token, str): token = use_auth_token elif use_auth_token is True: token = HfFolder.get_token() else: token = None - model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info( - path_or_repo, revision=revision, token=token - ) - return [f.rfilename for f in model_info.siblings] + return list_repo_files(path_or_repo, revision=revision, token=token) class cached_property(property): @@ -2308,7 +2305,7 @@ class PushToHubMixin: token = None # Special provision for the test endpoint (CI) - return HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).create_repo( + return create_repo( token, repo_name, organization=organization, @@ -2366,7 +2363,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: if token is None: token = HfFolder.get_token() if organization is None: - username = HfApi().whoami(token)["name"] + username = whoami(token)["name"] return f"{username}/{model_id}" else: return f"{organization}/{model_id}" diff --git a/src/transformers/modelcard.py b/src/transformers/modelcard.py index 52f385fd07..6eb8382cdf 100644 --- a/src/transformers/modelcard.py +++ b/src/transformers/modelcard.py @@ -25,7 +25,7 @@ from typing import Any, Dict, List, Optional, Union import requests import yaml -from huggingface_hub import HfApi +from huggingface_hub import model_info from . import __version__ from .file_utils import ( @@ -387,8 +387,8 @@ class TrainingSummary: and len(self.finetuned_from) > 0 ): try: - model_info = HfApi().model_info(self.finetuned_from) - for tag in model_info.tags: + info = model_info(self.finetuned_from) + for tag in info.tags: if tag.startswith("license:"): self.license = tag[8:] except requests.exceptions.HTTPError: diff --git a/src/transformers/models/marian/convert_marian_to_pytorch.py b/src/transformers/models/marian/convert_marian_to_pytorch.py index 7e49a36ad6..6603916254 100644 --- a/src/transformers/models/marian/convert_marian_to_pytorch.py +++ b/src/transformers/models/marian/convert_marian_to_pytorch.py @@ -27,7 +27,7 @@ import torch from torch import nn from tqdm import tqdm -from huggingface_hub.hf_api import HfApi +from huggingface_hub.hf_api import list_models from transformers import MarianConfig, MarianMTModel, MarianTokenizer @@ -64,8 +64,7 @@ def load_layers_(layer_lst: nn.ModuleList, opus_state: dict, converter, is_decod def find_pretrained_model(src_lang: str, tgt_lang: str) -> List[str]: """Find models that can accept src_lang as input and return tgt_lang as output.""" prefix = "Helsinki-NLP/opus-mt-" - api = HfApi() - model_list = api.list_models() + model_list = list_models() model_ids = [x.modelId for x in model_list if x.modelId.startswith("Helsinki-NLP")] src_and_targ = [ remove_prefix(m, prefix).lower().split("-") for m in model_ids if "+" not in m diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index 760a860d1b..da675e45ef 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -19,11 +19,11 @@ import os import tempfile import unittest -from huggingface_hub import HfApi +from huggingface_hub import delete_repo, login from requests.exceptions import HTTPError from transformers import BertConfig, GPT2Config, is_torch_available from transformers.configuration_utils import PretrainedConfig -from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test +from transformers.testing_utils import PASS, USER, is_staging_test config_common_kwargs = { @@ -194,18 +194,17 @@ class ConfigTester(object): class ConfigPushToHubTester(unittest.TestCase): @classmethod def setUpClass(cls): - cls._api = HfApi(endpoint=ENDPOINT_STAGING) - cls._token = cls._api.login(username=USER, password=PASS) + cls._token = login(username=USER, password=PASS) @classmethod def tearDownClass(cls): try: - cls._api.delete_repo(token=cls._token, name="test-config") + delete_repo(token=cls._token, name="test-config") except HTTPError: pass try: - cls._api.delete_repo(token=cls._token, name="test-config-org", organization="valid_org") + delete_repo(token=cls._token, name="test-config-org", organization="valid_org") except HTTPError: pass diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a5ba9d2989..9dfe9275fd 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -28,13 +28,12 @@ from typing import Dict, List, Tuple import numpy as np import transformers -from huggingface_hub import HfApi, Repository +from huggingface_hub import Repository, delete_repo, login from requests.exceptions import HTTPError from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available from transformers.models.auto import get_values from transformers.testing_utils import ( - ENDPOINT_STAGING, PASS, USER, CaptureLogger, @@ -2122,23 +2121,22 @@ class FakeModel(PreTrainedModel): class ModelPushToHubTester(unittest.TestCase): @classmethod def setUpClass(cls): - cls._api = HfApi(endpoint=ENDPOINT_STAGING) - cls._token = cls._api.login(username=USER, password=PASS) + cls._token = login(username=USER, password=PASS) @classmethod def tearDownClass(cls): try: - cls._api.delete_repo(token=cls._token, name="test-model") + delete_repo(token=cls._token, name="test-model") except HTTPError: pass try: - cls._api.delete_repo(token=cls._token, name="test-model-org", organization="valid_org") + delete_repo(token=cls._token, name="test-model-org", organization="valid_org") except HTTPError: pass try: - cls._api.delete_repo(token=cls._token, name="test-dynamic-model") + delete_repo(token=cls._token, name="test-dynamic-model") except HTTPError: pass diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 4e5acbfa65..228084a0df 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -22,19 +22,11 @@ from typing import List, Tuple import numpy as np import transformers -from huggingface_hub import HfApi +from huggingface_hub import delete_repo, login from requests.exceptions import HTTPError from transformers import BertConfig, is_flax_available, is_torch_available from transformers.models.auto import get_values -from transformers.testing_utils import ( - ENDPOINT_STAGING, - PASS, - USER, - CaptureLogger, - is_pt_flax_cross_test, - is_staging_test, - require_flax, -) +from transformers.testing_utils import PASS, USER, CaptureLogger, is_pt_flax_cross_test, is_staging_test, require_flax from transformers.utils import logging @@ -627,18 +619,17 @@ class FlaxModelTesterMixin: class FlaxModelPushToHubTester(unittest.TestCase): @classmethod def setUpClass(cls): - cls._api = HfApi(endpoint=ENDPOINT_STAGING) - cls._token = cls._api.login(username=USER, password=PASS) + cls._token = login(username=USER, password=PASS) @classmethod def tearDownClass(cls): try: - cls._api.delete_repo(token=cls._token, name="test-model-flax") + delete_repo(token=cls._token, name="test-model-flax") except HTTPError: pass try: - cls._api.delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org") + delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org") except HTTPError: pass diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index 07da2d7ff0..82d4742509 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -17,7 +17,7 @@ import tempfile import unittest -from huggingface_hub.hf_api import HfApi +from huggingface_hub.hf_api import list_models from transformers import MarianConfig, is_torch_available from transformers.file_utils import cached_property from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device @@ -296,7 +296,7 @@ class ModelManagementTests(unittest.TestCase): @slow @require_torch def test_model_names(self): - model_list = HfApi().list_models() + model_list = list_models() model_ids = [x.modelId for x in model_list if x.modelId.startswith(ORG_NAME)] bad_model_ids = [mid for mid in model_ids if "+" in model_ids] self.assertListEqual([], bad_model_ids) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 2709b4ff76..64ca24eeb6 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -24,12 +24,11 @@ import unittest from importlib import import_module from typing import List, Tuple -from huggingface_hub import HfApi +from huggingface_hub import delete_repo, login from requests.exceptions import HTTPError from transformers import is_tf_available from transformers.models.auto import get_values from transformers.testing_utils import ( - ENDPOINT_STAGING, PASS, USER, CaptureLogger, @@ -1530,18 +1529,17 @@ class UtilsFunctionsTest(unittest.TestCase): class TFModelPushToHubTester(unittest.TestCase): @classmethod def setUpClass(cls): - cls._api = HfApi(endpoint=ENDPOINT_STAGING) - cls._token = cls._api.login(username=USER, password=PASS) + cls._token = login(username=USER, password=PASS) @classmethod def tearDownClass(cls): try: - cls._api.delete_repo(token=cls._token, name="test-model-tf") + delete_repo(token=cls._token, name="test-model-tf") except HTTPError: pass try: - cls._api.delete_repo(token=cls._token, name="test-model-tf-org", organization="valid_org") + delete_repo(token=cls._token, name="test-model-tf-org", organization="valid_org") except HTTPError: pass diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 4495467a57..d733d2d4b0 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -27,7 +27,7 @@ from collections import OrderedDict from itertools import takewhile from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union -from huggingface_hub import HfApi +from huggingface_hub import delete_repo, login from requests.exceptions import HTTPError from transformers import ( AlbertTokenizer, @@ -44,7 +44,6 @@ from transformers import ( is_torch_available, ) from transformers.testing_utils import ( - ENDPOINT_STAGING, PASS, USER, get_tests_dir, @@ -3520,18 +3519,17 @@ class TokenizerPushToHubTester(unittest.TestCase): @classmethod def setUpClass(cls): - cls._api = HfApi(endpoint=ENDPOINT_STAGING) - cls._token = cls._api.login(username=USER, password=PASS) + cls._token = login(username=USER, password=PASS) @classmethod def tearDownClass(cls): try: - cls._api.delete_repo(token=cls._token, name="test-tokenizer") + delete_repo(token=cls._token, name="test-tokenizer") except HTTPError: pass try: - cls._api.delete_repo(token=cls._token, name="test-tokenizer-org", organization="valid_org") + delete_repo(token=cls._token, name="test-tokenizer-org", organization="valid_org") except HTTPError: pass diff --git a/tests/test_trainer.py b/tests/test_trainer.py index b1e6e0f5a9..0f1b6a5ff6 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -26,7 +26,7 @@ from pathlib import Path import numpy as np -from huggingface_hub import HfApi, Repository +from huggingface_hub import Repository, delete_repo, login from requests.exceptions import HTTPError from transformers import ( AutoTokenizer, @@ -1307,19 +1307,18 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): class TrainerIntegrationWithHubTester(unittest.TestCase): @classmethod def setUpClass(cls): - cls._api = HfApi(endpoint=ENDPOINT_STAGING) - cls._token = cls._api.login(username=USER, password=PASS) + cls._token = login(username=USER, password=PASS) @classmethod def tearDownClass(cls): for model in ["test-trainer", "test-trainer-epoch", "test-trainer-step"]: try: - cls._api.delete_repo(token=cls._token, name=model) + delete_repo(token=cls._token, name=model) except HTTPError: pass try: - cls._api.delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org") + delete_repo(token=cls._token, name="test-trainer-org", organization="valid_org") except HTTPError: pass @@ -1396,6 +1395,10 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): print(commits, len(commits)) def test_push_to_hub_with_saves_each_n_steps(self): + num_gpus = max(1, get_gpu_count()) + if num_gpus > 2: + return + with tempfile.TemporaryDirectory() as tmp_dir: trainer = get_regression_trainer( output_dir=os.path.join(tmp_dir, "test-trainer-step"), @@ -1409,7 +1412,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase): with tempfile.TemporaryDirectory() as tmp_dir: _ = Repository(tmp_dir, clone_from=f"{USER}/test-trainer-step", use_auth_token=self._token) commits = self.get_commit_history(tmp_dir) - expected_commits = [f"Training in progress, step {i}" for i in range(20, 0, -5)] + total_steps = 20 // num_gpus + expected_commits = [f"Training in progress, step {i}" for i in range(total_steps, 0, -5)] expected_commits.append("initial commit") self.assertListEqual(commits, expected_commits) print(commits, len(commits))