TF model cards (#14720)

* Initial commit for Keras model cards

* Revert accidental change

* make style

* make style

* make style

* Fix PR comments

* Move repo creation to __init__

* Fixes to README.md creation

* Partial progress for proper card creation on `push_to_hub`

* Proper card creation from `push_to_hub` plus fixes for malformed model cards

* Fixes for model card creation outside the callback

* Adding a model card creation test

* Putting the model card creation test in the right file.
Good job, Matt.

* make style

* Fix model card test temp dir usage

* Fix model card creation when no optimizer present

* Fixes for when training history not present

* Fix accidental edit to test_modeling_common
Matt authored on 2021-12-15 14:57:52 +00:00, committed by GitHub
parent 72c6e8b8bf
commit 48d4827697
5 changed files with 221 additions and 7 deletions


@@ -2335,6 +2335,7 @@ class PushToHubMixin:
organization: Optional[str] = None,
private: Optional[bool] = None,
use_auth_token: Optional[Union[bool, str]] = None,
**model_card_kwargs
) -> str:
"""
Upload the {object_files} to the 🤗 Model Hub while synchronizing a local clone of the repo in
@@ -2409,6 +2410,14 @@
)
# Save the files in the cloned repo
self.save_pretrained(repo_path_or_name)
if hasattr(self, "history") and hasattr(self, "create_model_card"):
# This is a Keras model and we might be able to fish out its History and make a model card out of it
base_model_card_args = {
"output_dir": repo_path_or_name,
"model_name": Path(repo_path_or_name).name,
}
base_model_card_args.update(model_card_kwargs)
self.create_model_card(**base_model_card_args)
# Commit and push!
url = self._push_to_hub(repo, commit_message=commit_message)
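
For context, a minimal usage sketch of the forwarded keyword arguments (the repo name and values below are illustrative, not from the diff): extra kwargs passed to `push_to_hub` are collected into `model_card_kwargs` and handed to `create_model_card` when the model is a Keras model with a training history.

from transformers import TFBertForSequenceClassification

model = TFBertForSequenceClassification.from_pretrained("bert-base-cased")
# ... model.fit(...) on your own data ...
model.push_to_hub(
    "my-finetuned-bert",                  # repo_path_or_name (placeholder)
    commit_message="Add fine-tuned model",
    tags="text-classification",           # forwarded to create_model_card via **model_card_kwargs
    language="en",
)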


@@ -10,6 +10,7 @@ from huggingface_hub import Repository
from . import IntervalStrategy, PreTrainedTokenizerBase
from .file_utils import get_full_repo_name
from .modelcard import TrainingSummary
logger = logging.getLogger(__name__)
@@ -25,6 +26,7 @@ class PushToHubCallback(Callback):
hub_model_id: Optional[str] = None,
hub_token: Optional[str] = None,
checkpoint: bool = False,
**model_card_args
):
"""
output_dir (:obj:`str`):
@@ -70,12 +72,22 @@ class PushToHubCallback(Callback):
hub_model_id = get_full_repo_name(hub_model_id, token=hub_token)
self.output_dir = output_dir
self.hub_model_id = hub_model_id
self.repo = Repository(
-    str(output_dir), clone_from=hub_model_id, use_auth_token=hub_token if hub_token else True
+    str(self.output_dir),
+    clone_from=self.hub_model_id,
+    use_auth_token=hub_token if hub_token else True,
)
self.tokenizer = tokenizer
self.last_job = None
self.checkpoint = checkpoint
self.training_history = None
self.model_card_args = model_card_args
def on_train_begin(self, logs=None):
# Although we can access model.history, we have no guarantees that the History callback will fire before this
# one, so we keep track of it here too
self.training_history = []
def on_train_batch_end(self, batch, logs=None):
if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
@@ -89,6 +101,9 @@ class PushToHubCallback(Callback):
)
def on_epoch_end(self, epoch, logs=None):
if "epoch" not in logs:
logs["epoch"] = epoch
self.training_history.append(logs)
if self.save_strategy == IntervalStrategy.EPOCH:
if self.last_job is not None and not self.last_job.is_done:
return # The last upload is still running, don't start another
@@ -98,6 +113,15 @@ class PushToHubCallback(Callback):
if self.checkpoint:
checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
self.model._save_checkpoint(checkpoint_dir, epoch)
train_summary = TrainingSummary.from_keras(
model=self.model,
model_name=self.hub_model_id,
keras_history=self.training_history,
**self.model_card_args,
)
model_card = train_summary.to_model_card()
with (self.output_dir / "README.md").open("w") as f:
f.write(model_card)
_, self.last_job = self.repo.push_to_hub(
commit_message=f"Training in progress epoch {epoch}", blocking=False
)
@@ -110,4 +134,10 @@ class PushToHubCallback(Callback):
self.model.save_pretrained(self.output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(self.output_dir)
train_summary = TrainingSummary.from_keras(
model=self.model, model_name=self.hub_model_id, keras_history=self.training_history, **self.model_card_args
)
model_card = train_summary.to_model_card()
with (self.output_dir / "README.md").open("w") as f:
f.write(model_card)
self.repo.push_to_hub(commit_message="End of training", blocking=True)
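
A hedged sketch of wiring the callback up from user code (dataset, tokenizer, and repo names are placeholders): keyword arguments beyond the documented ones are stored as `model_card_args` and forwarded to `TrainingSummary.from_keras` on every push.

from transformers.keras_callbacks import PushToHubCallback

push_callback = PushToHubCallback(
    output_dir="./clf_model_save",                  # local clone of the Hub repo (placeholder)
    tokenizer=tokenizer,                            # assumed to exist from earlier preprocessing
    hub_model_id="my-username/bert-finetuned-clf",  # placeholder repo id
    # Everything below is captured by **model_card_args and lands in the generated README.md
    tags="text-classification",
    finetuned_from="bert-base-cased",
)
model.fit(train_dataset, validation_data=eval_dataset, epochs=3, callbacks=[push_callback])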


@@ -38,6 +38,7 @@ from .file_utils import (
is_datasets_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
)
@@ -266,11 +267,16 @@ class ModelCard:
writer.write(self.to_json_string())
-AUTOGENERATED_COMMENT = """
+AUTOGENERATED_TRAINER_COMMENT = """
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
"""
AUTOGENERATED_KERAS_COMMENT = """
<!-- This model card has been generated automatically according to the information Keras had access to. You should
probably proofread and complete it, then remove this comment. -->
"""
TASK_TAG_TO_NAME_MAPPING = {
"fill-mask": "Masked Language Modeling",
@@ -377,6 +383,7 @@ class TrainingSummary:
eval_results: Optional[Dict[str, float]] = None
eval_lines: Optional[List[str]] = None
hyperparameters: Optional[Dict[str, Any]] = None
source: Optional[str] = "trainer"
def __post_init__(self):
# Infer default license from the checkpoint used, if possible.
@@ -410,15 +417,15 @@
task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
}
model_index["results"] = []
if len(task_mapping) == 0 and len(dataset_mapping) == 0:
-    return model_index
+    return [model_index]
if len(task_mapping) == 0:
task_mapping = {None: None}
if len(dataset_mapping) == 0:
dataset_mapping = {None: None}
model_index["results"] = []
# One entry per dataset and per task
all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]
for task_tag, ds_tag in all_possibilities:
@@ -471,7 +478,10 @@
model_card = f"---\n{metadata}---\n"
# Now the model card for realsies.
-model_card += AUTOGENERATED_COMMENT
+if self.source == "trainer":
+    model_card += AUTOGENERATED_TRAINER_COMMENT
+else:
+    model_card += AUTOGENERATED_KERAS_COMMENT
model_card += f"\n# {self.model_name}\n\n"
@@ -517,10 +527,15 @@
model_card += "\n### Framework versions\n\n"
model_card += f"- Transformers {__version__}\n"
-if is_torch_available():
+if self.source == "trainer" and is_torch_available():
import torch
model_card += f"- Pytorch {torch.__version__}\n"
elif self.source == "keras" and is_tf_available():
import tensorflow as tf
model_card += f"- TensorFlow {tf.__version__}\n"
if is_datasets_available():
import datasets
@@ -604,6 +619,113 @@ class TrainingSummary:
hyperparameters=hyperparameters,
)
@classmethod
def from_keras(
cls,
model,
model_name,
keras_history=None,
language=None,
license=None,
tags=None,
finetuned_from=None,
tasks=None,
dataset_tags=None,
dataset=None,
dataset_args=None,
):
# Infer default from dataset
if dataset is not None:
if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None):
default_tag = dataset.builder_name
# Those are not real datasets from the Hub so we exclude them.
if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
if dataset_tags is None:
dataset_tags = [default_tag]
if dataset_args is None:
dataset_args = [dataset.config_name]
if dataset is None and dataset_tags is not None:
dataset = dataset_tags
# Infer default finetuned_from
if (
finetuned_from is None
and hasattr(model.config, "_name_or_path")
and not os.path.isdir(model.config._name_or_path)
):
finetuned_from = model.config._name_or_path
# Infer default task tag:
if tasks is None:
model_class_name = model.__class__.__name__
for task, mapping in TASK_MAPPING.items():
if model_class_name in _get_mapping_values(mapping):
tasks = task
# Add `generated_from_keras_callback` to the tags
if tags is None:
tags = ["generated_from_keras_callback"]
elif isinstance(tags, str) and tags != "generated_from_keras_callback":
tags = [tags, "generated_from_keras_callback"]
elif "generated_from_trainer" not in tags:
tags.append("generated_from_keras_callback")
if keras_history is not None:
_, eval_lines, eval_results = parse_keras_history(keras_history)
else:
eval_lines = []
eval_results = dict()
hyperparameters = extract_hyperparameters_from_keras(model)
return cls(
language=language,
license=license,
tags=tags,
model_name=model_name,
finetuned_from=finetuned_from,
tasks=tasks,
dataset_tags=dataset_tags,
dataset=dataset,
dataset_args=dataset_args,
eval_results=eval_results,
eval_lines=eval_lines,
hyperparameters=hyperparameters,
source="keras",
)
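# Usage sketch (illustrative, assumes `model` has already been trained with model.fit()):
# the new classmethod can also be called directly, outside the callback. The repo name and
# task below are placeholders; `keras_history` can be the History object returned by
# model.fit() or the accumulated list of log dicts handled by parse_keras_history below.
summary = TrainingSummary.from_keras(
    model=model,
    model_name="my-finetuned-bert",
    keras_history=model.history,
    tasks="text-classification",
)
readme_text = summary.to_model_card()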
def parse_keras_history(logs):
"""
Parse the `logs` of either a `tf.keras.callbacks.History` object returned by `model.fit()` or an accumulated list of
log dicts passed to the `PushToHubCallback`, and return lines and logs compatible with those returned by `parse_log_history`.
"""
if hasattr(logs, "history"):
# This looks like a `History` object
logs.history["epoch"] = logs.epoch
logs = logs.history
else:
# Training logs is a list of dicts, let's invert it to a dict of lists to match a History object
logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]}
lines = []
for i in range(len(logs["epoch"])):
epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
values = dict()
for k, v in epoch_dict.items():
if k.startswith("val_"):
k = "validation_" + k[4:]
elif k != "epoch":
k = "train_" + k
splits = k.split("_")
name = " ".join([part.capitalize() for part in splits])
values[name] = v
lines.append(values)
eval_results = lines[-1]
return logs, lines, eval_results
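# Illustrative example (values invented) of the second input shape parse_keras_history
# accepts: the list of per-epoch log dicts accumulated by PushToHubCallback. A History
# object from model.fit() works the same way, via its .history dict.
accumulated_logs = [
    {"loss": 0.52, "val_loss": 0.61, "epoch": 0},
    {"loss": 0.31, "val_loss": 0.47, "epoch": 1},
]
logs, lines, eval_results = parse_keras_history(accumulated_logs)
# lines[-1] (and therefore eval_results) comes out roughly as:
# {"Train Loss": 0.31, "Validation Loss": 0.47, "Epoch": 1}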
def parse_log_history(log_history):
"""
@ -666,6 +788,19 @@ def parse_log_history(log_history):
return train_log, lines, None
def extract_hyperparameters_from_keras(model):
import tensorflow as tf
hyperparameters = dict()
if hasattr(model, "optimizer") and model.optimizer is not None:
hyperparameters["optimizer"] = model.optimizer.get_config()
else:
hyperparameters["optimizer"] = None
hyperparameters["training_precision"] = tf.keras.mixed_precision.global_policy().name
return hyperparameters
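# Illustrative example (exact keys depend on the optimizer's get_config()): for a model
# compiled with Adam under the default float32 policy, the result might look like
#     {"optimizer": {"name": "Adam", "learning_rate": 3e-05, "beta_1": 0.9, ...},
#      "training_precision": "float32"}
import tensorflow as tf
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5), loss="mse")
print(extract_hyperparameters_from_keras(model))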
def _maybe_round(v, decimals=4):
if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
return f"{v:.{decimals}f}"


@@ -47,6 +47,7 @@ from .file_utils import (
is_remote_url,
)
from .generation_tf_utils import TFGenerationMixin
from .modelcard import TrainingSummary
from .modeling_tf_outputs import TFSeq2SeqLMOutput
from .tokenization_utils_base import BatchEncoding
from .utils import logging
@@ -926,6 +927,36 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
del return_metrics["loss_loss"]
return return_metrics
def create_model_card(
self,
output_dir,
model_name: str,
language: Optional[str] = None,
license: Optional[str] = None,
tags: Optional[str] = None,
finetuned_from: Optional[str] = None,
tasks: Optional[str] = None,
dataset_tags: Optional[Union[str, List[str]]] = None,
dataset: Optional[Union[str, List[str]]] = None,
dataset_args: Optional[Union[str, List[str]]] = None,
):
training_summary = TrainingSummary.from_keras(
self,
keras_history=self.history,
language=language,
license=license,
tags=tags,
model_name=model_name,
finetuned_from=finetuned_from,
tasks=tasks,
dataset_tags=dataset_tags,
dataset=dataset,
dataset_args=dataset_args,
)
model_card = training_summary.to_model_card()
with open(os.path.join(output_dir, "README.md"), "w") as f:
f.write(model_card)
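# Hedged usage sketch of the new method (paths and names are placeholders). It reads the
# History object Keras attaches to the model during fit(), so it is meant to be called after
# training; the generated card is written to <output_dir>/README.md.
model.fit(train_dataset, validation_data=eval_dataset, epochs=3)
model.create_model_card(
    output_dir="./my_model",
    model_name="my-finetuned-bert",
    finetuned_from="bert-base-cased",
    tags="text-classification",
)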
def set_input_embeddings(self, value):
"""
Set model's input embeddings


@@ -1386,6 +1386,15 @@ class TFModelPushToHubTester(unittest.TestCase):
models_equal = False
self.assertTrue(models_equal)
def test_push_to_hub_with_model_card(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
model = TFBertModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.push_to_hub(os.path.join(tmp_dir, "test-model-card-tf"))
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "test-model-card-tf", "README.md")))
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37