Refactor internals for Trainer push_to_hub (#13486)
This commit is contained in:
parent
3dd538c4d3
commit
e59d4d0147
|
@ -2238,3 +2238,13 @@ class PushToHubMixin:
|
|||
commit_message = "add model"
|
||||
|
||||
return repo.push_to_hub(commit_message=commit_message)
|
||||
|
||||
|
||||
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None) -> str:
    """
    Build the full ``namespace/model_id`` name of a repository on the Hub.

    Args:
        model_id (:obj:`str`):
            The name of the repository, without any namespace.
        organization (:obj:`str`, `optional`):
            The namespace to use. When :obj:`None`, the ``"name"`` entry returned by ``HfApi().whoami(token)`` is
            used instead.
        token (:obj:`str`, `optional`):
            The token passed to ``whoami``. It is only consulted when :obj:`organization` is :obj:`None`, and
            falls back to ``HfFolder.get_token()`` when not provided.

    Returns:
        :obj:`str`: ``"{organization}/{model_id}"`` when an organization is given, ``"{username}/{model_id}"``
        otherwise.
    """
    if organization is not None:
        # The namespace is explicit, so neither the stored token nor a Hub query is needed.
        return f"{organization}/{model_id}"

    if token is None:
        token = HfFolder.get_token()
    username = HfApi().whoami(token)["name"]
    return f"{username}/{model_id}"
|
||||
|
|
|
@ -51,6 +51,8 @@ from torch import nn
|
|||
from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from huggingface_hub import Repository
|
||||
|
||||
from . import __version__
|
||||
from .configuration_utils import PretrainedConfig
|
||||
from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
|
||||
|
@ -60,7 +62,7 @@ from .dependency_versions_check import dep_version_check
|
|||
from .file_utils import (
|
||||
CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
PushToHubMixin,
|
||||
get_full_repo_name,
|
||||
is_apex_available,
|
||||
is_datasets_available,
|
||||
is_in_notebook,
|
||||
|
@ -2478,15 +2480,17 @@ class Trainer:
|
|||
"""
|
||||
if not self.args.should_save:
|
||||
return
|
||||
use_auth_token = True if self.args.push_to_hub_token is None else self.args.push_to_hub_token
|
||||
repo_url = PushToHubMixin._get_repo_url_from_name(
|
||||
self.args.push_to_hub_model_id,
|
||||
organization=self.args.push_to_hub_organization,
|
||||
use_auth_token = True if self.args.hub_token is None else self.args.hub_token
|
||||
if self.args.hub_model_id is None:
|
||||
repo_name = get_full_repo_name(Path(self.args.output_dir).name, token=self.args.hub_token)
|
||||
else:
|
||||
repo_name = self.args.hub_model_id
|
||||
|
||||
self.repo = Repository(
|
||||
self.args.output_dir,
|
||||
clone_from=repo_name,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
self.repo = PushToHubMixin._create_or_get_repo(
|
||||
self.args.output_dir, repo_url=repo_url, use_auth_token=use_auth_token
|
||||
)
|
||||
|
||||
# By default, ignore the checkpoint folders
|
||||
if not os.path.exists(os.path.join(self.args.output_dir, ".gitignore")):
|
||||
|
@ -2523,7 +2527,7 @@ class Trainer:
|
|||
|
||||
def push_to_hub(self, commit_message: Optional[str] = "add model", **kwargs) -> str:
|
||||
"""
|
||||
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.push_to_hub_model_id`.
|
||||
Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
|
||||
|
||||
Parameters:
|
||||
commit_message (:obj:`str`, `optional`, defaults to :obj:`"add model"`):
|
||||
|
@ -2536,7 +2540,11 @@ class Trainer:
|
|||
"""
|
||||
|
||||
if self.args.should_save:
|
||||
self.create_model_card(model_name=self.args.push_to_hub_model_id, **kwargs)
|
||||
if self.args.hub_model_id is None:
|
||||
model_name = Path(self.args.output_dir).name
|
||||
else:
|
||||
model_name = self.args.hub_model_id.split("/")[-1]
|
||||
self.create_model_card(model_name=model_name, **kwargs)
|
||||
# Needs to be executed on all processes for TPU training, but will only save on the process determined by
|
||||
# self.args.should_save.
|
||||
self.save_model()
|
||||
|
|
|
@ -25,6 +25,7 @@ from typing import Any, Dict, List, Optional
|
|||
from .debug_utils import DebugOption
|
||||
from .file_utils import (
|
||||
cached_property,
|
||||
get_full_repo_name,
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_available,
|
||||
|
@ -335,12 +336,14 @@ class TrainingArguments:
|
|||
:class:`~transformers.Trainer`, it's intended to be used by your training/evaluation scripts instead. See
|
||||
the `example scripts <https://github.com/huggingface/transformers/tree/master/examples>`__ for more
|
||||
details.
|
||||
push_to_hub_model_id (:obj:`str`, `optional`):
|
||||
The name of the repository to push the :class:`~transformers.Trainer` to when :obj:`push_to_hub=True`.
|
||||
Will default to the name of :obj:`output_dir`.
|
||||
push_to_hub_organization (:obj:`str`, `optional`):
|
||||
The name of the organization to push the :class:`~transformers.Trainer` to.
|
||||
push_to_hub_token (:obj:`str`, `optional`):
|
||||
hub_model_id (:obj:`str`, `optional`):
|
||||
The name of the repository to keep in sync with the local `output_dir`. Should be the whole repository
|
||||
name, for instance :obj:`"user_name/model"`, which allows you to push to an organization you are a member
|
||||
of with :obj:`"organization_name/model"`.
|
||||
|
||||
Will default to :obj:`user_name/output_dir_name` with `output_dir_name` being the name of
|
||||
:obj:`output_dir`.
|
||||
hub_token (:obj:`str`, `optional`):
|
||||
The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with
|
||||
:obj:`huggingface-cli login`.
|
||||
"""
|
||||
|
@ -612,6 +615,11 @@ class TrainingArguments:
|
|||
default=None,
|
||||
metadata={"help": "The path to a folder with a valid checkpoint for your model."},
|
||||
)
|
||||
hub_model_id: str = field(
|
||||
default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
|
||||
)
|
||||
hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})
|
||||
# Deprecated arguments
|
||||
push_to_hub_model_id: str = field(
|
||||
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
|
||||
)
|
||||
|
@ -761,8 +769,40 @@ class TrainingArguments:
|
|||
self.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.deepspeed)
|
||||
self.hf_deepspeed_config.trainer_config_process(self)
|
||||
|
||||
if self.push_to_hub_model_id is None:
|
||||
self.push_to_hub_model_id = Path(self.output_dir).name
|
||||
if self.push_to_hub_token is not None:
|
||||
warnings.warn(
|
||||
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
|
||||
"`--hub_token` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
self.hub_token = self.push_to_hub_token
|
||||
|
||||
if self.push_to_hub_model_id is not None:
|
||||
self.hub_model_id = get_full_repo_name(
|
||||
self.push_to_hub_model_id, organization=self.push_to_hub_organization, token=self.hub_token
|
||||
)
|
||||
if self.push_to_hub_organization is not None:
|
||||
warnings.warn(
|
||||
"`--push_to_hub_model_id` and `--push_to_hub_organization` are deprecated and will be removed in "
|
||||
"version 5 of 🤗 Transformers. Use `--hub_model_id` instead and pass the full repo name to this "
|
||||
f"argument (in this case {self.hub_model_id}).",
|
||||
FutureWarning,
|
||||
)
|
||||
else:
|
||||
warnings.warn(
|
||||
"`--push_to_hub_model_id` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
|
||||
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
|
||||
f"{self.hub_model_id}).",
|
||||
FutureWarning,
|
||||
)
|
||||
elif self.push_to_hub_organization is not None:
|
||||
self.hub_model_id = f"{self.push_to_hub_organization}/{Path(self.output_dir).name}"
|
||||
warnings.warn(
|
||||
"`--push_to_hub_organization` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
|
||||
"`--hub_model_id` instead and pass the full repo name to this argument (in this case "
|
||||
f"{self.hub_model_id}).",
|
||||
FutureWarning,
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
self_as_dict = asdict(self)
|
||||
|
|
|
@ -1299,7 +1299,7 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||
trainer = get_regression_trainer(
|
||||
output_dir=os.path.join(tmp_dir, "test-trainer"),
|
||||
push_to_hub=True,
|
||||
push_to_hub_token=self._token,
|
||||
hub_token=self._token,
|
||||
)
|
||||
url = trainer.push_to_hub()
|
||||
|
||||
|
@ -1321,8 +1321,8 @@ class TrainerIntegrationWithHubTester(unittest.TestCase):
|
|||
trainer = get_regression_trainer(
|
||||
output_dir=os.path.join(tmp_dir, "test-trainer-org"),
|
||||
push_to_hub=True,
|
||||
push_to_hub_organization="valid_org",
|
||||
push_to_hub_token=self._token,
|
||||
hub_model_id="valid_org/test-trainer-org",
|
||||
hub_token=self._token,
|
||||
)
|
||||
url = trainer.push_to_hub()
|
||||
|
||||
|
|
Loading…
Reference in New Issue