Refactor internals for Trainer push_to_hub (#13486)

This commit is contained in:
Sylvain Gugger 2021-09-09 13:04:37 -04:00 committed by GitHub
parent 3dd538c4d3
commit e59d4d0147
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 79 additions and 21 deletions

View File

@ -2238,3 +2238,13 @@ class PushToHubMixin:
commit_message = "add model"
return repo.push_to_hub(commit_message=commit_message)
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = HfApi().whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"

View File

@ -51,6 +51,8 @@ from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from huggingface_hub import Repository
from . import __version__
from .configuration_utils import PretrainedConfig
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
@ -60,7 +62,7 @@ from .dependency_versions_check import dep_version_check
from .file_utils import (
CONFIG_NAME,
WEIGHTS_NAME,
PushToHubMixin,
get_full_repo_name,
is_apex_available,
is_datasets_available,
is_in_notebook,
@ -2478,15 +2480,17 @@ class Trainer:
"""
if not self.args.should_save:
return
use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token
repo_url = PushToHubMixin._get_repo_url_from_name(
self.args.push_to_hub_model_id,
organization=self.args.push_to_hub_organization,
use_auth_token = True if self.args.hub_token is None else self.args.hub_token
if self.args.hub_model_id is None:
repo_name = get_full_repo_name(Path(self.args.output_dir).name, token=self.args.hub_token)
else:
repo_name = self.args.hub_model_id
self.repo = Repository(
self.args.output_dir,
clone_from=repo_name,
use_auth_token=use_auth_token,
)
self.repo = PushToHubMixin._create_or_get_repo(
self.args.output_dir, repo_url=repo_url, use_auth_token=use_auth_token
)
# By default, ignore the checkpoint folders
if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
@ -2523,7 +2527,7 @@ class Trainer:
def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
"""
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.push_to_hub_model_id`.
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
Parameters:
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
@ -2536,7 +2540,11 @@ class Trainer:
"""
if self.args.should_save:
self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
if self.args.hub_model_id is None:
model_name = Path(self.args.output_dir).name
else:
model_name = self.args.hub_model_id.split("/")[-1]
self.create_model_card(model_name=model_name, **kwargs)
# Needs to be executed on all processes for TPU training, but will only save on the processed determined by
# self.args.should_save.
self.save_model()

View File

@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional
from .debug_utils import DebugOption
from .file_utils import (
cached_property,
get_full_repo_name,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
is_torch_available,
@ -335,12 +336,14 @@ class TrainingArguments:
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
details.
push_to_hub_model_id (:obj:`str`, `optional`):
The name of the repository to which push the :class:`~transformers.Trainer` when :obj:`push_to_hub=True`.
Will default to the name of :obj:`output_dir`.
push_to_hub_organization (:obj:`str`, `optional`):
The name of the organization in with to which push the :class:`~transformers.Trainer`.
push_to_hub_token (:obj:`str`, `optional`):
hub_model_id (:obj:`str`, `optional`):
The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository
name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member
of with :obj:`"organization_name/model"`.
Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
:obj:`output_dir`.
hub_token (:obj:`str`, `optional`):
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
:obj:`huggingface-cli login`.
"""
@ -612,6 +615,11 @@ class TrainingArguments:
default=None,
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
)
hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
)
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
# Deprecated arguments
push_to_hub_model_id: str = field(
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
)
@ -761,8 +769,40 @@ class TrainingArguments:
self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
self.hf_deepspeed_config.trainer_config_process(self)
if self.push_to_hub_model_id is None:
self.push_to_hub_model_id = Path(self.output_dir).name
if self.push_to_hub_token is not None:
warnings.warn(
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_token` instead.",
FutureWarning,
)
self.hub_token = self.push_to_hub_token
if self.push_to_hub_model_id is not None:
self.hub_model_id = get_full_repo_name(
self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token
)
if self.push_to_hub_organization is not None:
warnings.warn(
"`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in "
"version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this "
f"argument (in this case {self.hub_model_id}).",
FutureWarning,
)
else:
warnings.warn(
"`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
f"{self.hub_model_id}).",
FutureWarning,
)
elif self.push_to_hub_organization is not None:
self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}"
warnings.warn(
"`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
f"{self.hub_model_id}).",
FutureWarning,
)
def __str__(self):
self_as_dict = asdict(self)

View File

@ -1299,7 +1299,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer"),
push_to_hub=True,
push_to_hub_token=self._token,
hub_token=self._token,
)
url = trainer.push_to_hub()
@ -1321,8 +1321,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
trainer = get_regression_trainer(
output_dir=os.path.join(tmp_dir, "test-trainer-org"),
push_to_hub=True,
push_to_hub_organization="valid_org",
push_to_hub_token=self._token,
hub_model_id="valid_org/test-trainer-org",
hub_token=self._token,
)
url = trainer.push_to_hub()