#!/usr/bin/env python
# coding=utf-8
# Copyright 2021 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Pre-training/Fine-tuning ViT for image classification.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=vit
"""
import logging
import os
import sys
import time
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from pathlib import Path
from typing import Callable, Optional

import jax
import jax.numpy as jnp
import optax

# for dataset and preprocessing
import torch
import torchvision
import torchvision.transforms as transforms
from flax import jax_utils
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from tqdm import tqdm

import transformers
from transformers import (
    CONFIG_MAPPING,
    FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
    AutoConfig,
    FlaxAutoModelForImageClassification,
    HfArgumentParser,
    is_tensorboard_available,
    set_seed,
)
from transformers.utils import send_example_telemetry


logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
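
# The dataclasses below (TrainingArguments, ModelArguments, DataTrainingArguments) define every
# command-line flag of this script; HfArgumentParser turns each field into a `--<field_name>`
# argument and parses them into dataclass instances inside main().
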
@dataclass
class TrainingArguments:
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    overwrite_output_dir: bool = field(
        default=False,
        metadata={
            "help": (
                "Overwrite the content of the output directory. "
                "Use this to continue training if output_dir points to a checkpoint directory."
            )
        },
    )
    do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
    do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
    per_device_train_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for training."}
    )
    per_device_eval_batch_size: int = field(
        default=8, metadata={"help": "Batch size per GPU/TPU core/CPU for evaluation."}
    )
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer."})
    adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer."})
    adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."})
    adafactor: bool = field(default=False, metadata={"help": "Whether or not to replace AdamW by Adafactor."})
    num_train_epochs: float = field(default=3.0, metadata={"help": "Total number of training epochs to perform."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    logging_steps: int = field(default=500, metadata={"help": "Log every X updates steps."})
    save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."})
    eval_steps: int = field(default=None, metadata={"help": "Run an evaluation every X steps."})
    seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
    push_to_hub: bool = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: str = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    hub_token: str = field(default=None, metadata={"help": "The token to use to push to the Model Hub."})

    def __post_init__(self):
        if self.output_dir is not None:
            self.output_dir = os.path.expanduser(self.output_dir)

    def to_dict(self):
        """
        Serializes this instance while replacing `Enum` members by their values (for JSON serialization support). It
        obfuscates the token values by replacing them with placeholders.
        """
        d = asdict(self)
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
                d[k] = [x.value for x in v]
            if k.endswith("_token"):
                d[k] = f"<{k.upper()}>"
        return d


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized and trained. Choose one of"
                " `[float32, float16, bfloat16]`."
            )
        },
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    use_auth_token: bool = field(
        default=None,
        metadata={
            "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead."
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option "
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    train_dir: str = field(
        metadata={"help": "Path to the root training directory which contains one subdirectory per class."}
    )
    validation_dir: str = field(
        metadata={"help": "Path to the root validation directory which contains one subdirectory per class."},
    )
    image_size: Optional[int] = field(default=224, metadata={"help": "The size (resolution) of each image."})
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
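
# Both train_dir and validation_dir are expected in torchvision ImageFolder layout, i.e. one
# subdirectory per class. A purely illustrative example (directory and class names are made up):
#
#   data/train/cats/001.jpg
#   data/train/dogs/001.jpg
#   data/val/cats/101.jpg
#   data/val/dogs/101.jpg
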
class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray

    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
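
# Keeping the dropout PRNG key inside the TrainState lets each training step split off a fresh key
# and store the new one back into the (replicated) state; `shard_prng_key` in `replicate` gives
# every device its own distinct key, so dropout masks differ across devices.
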
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
    summary_writer.scalar("train_time", train_time, step)

    train_metrics = get_metrics(train_metrics)
    for key, vals in train_metrics.items():
        tag = f"train_{key}"
        for i, val in enumerate(vals):
            summary_writer.scalar(tag, val, step - len(vals) + i + 1)

    for metric_name, value in eval_metrics.items():
        summary_writer.scalar(f"eval_{metric_name}", value, step)


def create_learning_rate_fn(
    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.ndarray]:
    """Returns a linear warmup, linear_decay learning rate function."""
    steps_per_epoch = train_ds_size // train_batch_size
    num_train_steps = steps_per_epoch * num_train_epochs
    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn
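
# Illustrative example of the schedule above (the numbers are made up): with 1,024 training images,
# train_batch_size=32, num_train_epochs=10 and warmup_steps=100, steps_per_epoch is 32 and
# num_train_steps is 320, so the learning rate rises linearly from 0 to `learning_rate` over the
# first 100 steps, then decays linearly back to 0 over the remaining 220 steps.
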
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if model_args.use_auth_token is not None:
        warnings.warn(
            "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token` instead.",
            FutureWarning,
        )
        if model_args.token is not None:
            raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
        model_args.token = model_args.use_auth_token

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_image_classification", model_args, data_args, framework="flax")

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. "
            "Use --overwrite_output_dir to overcome."
        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Setup logging: we only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        transformers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()

    # Set the verbosity to info of the Transformers logger (on main process only):
    logger.info(f"Training/evaluation parameters {training_args}")

    # set seed for random transforms and torch dataloaders
    set_seed(training_args.seed)

    # Handle the repository creation
    if training_args.push_to_hub:
        # Retrieve or infer repo_name
        repo_name = training_args.hub_model_id
        if repo_name is None:
            repo_name = Path(training_args.output_dir).absolute().name
        # Create repo and retrieve repo_id
        repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
        # Clone repo locally
        repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)

    # Initialize datasets and pre-processing transforms
    # We use torchvision here for faster pre-processing
    # Note that here we are using some default pre-processing; for maximum accuracy
    # one should tune this part and carefully select which transformations to use.
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_dataset = torchvision.datasets.ImageFolder(
        data_args.train_dir,
        transforms.Compose(
            [
                transforms.RandomResizedCrop(data_args.image_size),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )

    eval_dataset = torchvision.datasets.ImageFolder(
        data_args.validation_dir,
        transforms.Compose(
            [
                transforms.Resize(data_args.image_size),
                transforms.CenterCrop(data_args.image_size),
                transforms.ToTensor(),
                normalize,
            ]
        ),
    )
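
    # ImageFolder derives integer labels from the sorted class subdirectory names. The training
    # pipeline applies random resized crops and horizontal flips as augmentation, while evaluation
    # uses a deterministic resize followed by a center crop.
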
    # Load pretrained config and model
    if model_args.config_name:
        config = AutoConfig.from_pretrained(
            model_args.config_name,
            num_labels=len(train_dataset.classes),
            image_size=data_args.image_size,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
        )
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(
            model_args.model_name_or_path,
            num_labels=len(train_dataset.classes),
            image_size=data_args.image_size,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
        )
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")

    if model_args.model_name_or_path:
        model = FlaxAutoModelForImageClassification.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
        )
    else:
        model = FlaxAutoModelForImageClassification.from_config(
            config,
            seed=training_args.seed,
            dtype=getattr(jnp, model_args.dtype),
            trust_remote_code=model_args.trust_remote_code,
        )
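
    # num_labels is inferred from the training set's class folders above, so the validation
    # directory should contain the same class subdirectories for the label mapping to match.
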
    # Store some constants
    num_epochs = int(training_args.num_train_epochs)
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()
    steps_per_epoch = len(train_dataset) // train_batch_size
    total_train_steps = steps_per_epoch * num_epochs

    def collate_fn(examples):
        pixel_values = torch.stack([example[0] for example in examples])
        labels = torch.tensor([example[1] for example in examples])

        batch = {"pixel_values": pixel_values, "labels": labels}
        batch = {k: v.numpy() for k, v in batch.items()}

        return batch
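
    # collate_fn returns plain numpy arrays (rather than torch tensors) so that `shard` can later
    # reshape the leading batch dimension to (num_devices, per_device_batch_size, ...) for pmap.
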
    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=True,
        num_workers=data_args.preprocessing_num_workers,
        persistent_workers=True,
        drop_last=True,
        collate_fn=collate_fn,
    )

    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=data_args.preprocessing_num_workers,
        persistent_workers=True,
        drop_last=False,
        collate_fn=collate_fn,
    )
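
    # drop_last=True on the training loader guarantees every batch splits evenly across devices in
    # `shard`; the eval loader keeps its final partial batch, which `pad_shard_unpad` pads later.
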
    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard and jax.process_index() == 0:
        try:
            from flax.metrics.tensorboard import SummaryWriter

            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
        except ImportError as ie:
            has_tensorboard = False
            logger.warning(
                f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
            )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not installed. "
            "Please run `pip install tensorboard` to enable."
        )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        len(train_dataset),
        train_batch_size,
        training_args.num_train_epochs,
        training_args.warmup_steps,
        training_args.learning_rate,
    )

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )
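
    # Note: TrainingArguments declares an `adafactor` flag, but this example always builds AdamW
    # here; switching to Adafactor (e.g. via optax.adafactor) would be a manual change to this block.
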
    # Setup train state
    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)

    def loss_fn(logits, labels):
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1]))
        return loss.mean()

    # Define gradient update step fn
    def train_step(state, batch):
        dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)

        def compute_loss(params):
            labels = batch.pop("labels")
            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
            loss = loss_fn(logits, labels)
            return loss

        grad_fn = jax.value_and_grad(compute_loss)
        loss, grad = grad_fn(state.params)
        grad = jax.lax.pmean(grad, "batch")

        new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)

        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
        metrics = jax.lax.pmean(metrics, axis_name="batch")

        return new_state, metrics

    # Define eval fn
    def eval_step(params, batch):
        labels = batch.pop("labels")
        logits = model(**batch, params=params, train=False)[0]
        loss = loss_fn(logits, labels)

        # summarize metrics
        accuracy = (jnp.argmax(logits, axis=-1) == labels).mean()
        metrics = {"loss": loss, "accuracy": accuracy}
        metrics = jax.lax.pmean(metrics, axis_name="batch")
        return metrics

    # Create parallel version of the train and eval step
    p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
    p_eval_step = jax.pmap(eval_step, "batch")
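
    # `jax.pmap(..., "batch")` runs one step per device; inside the steps, `jax.lax.pmean` over the
    # "batch" axis averages gradients and metrics across devices so every replica applies the same
    # update. `donate_argnums=(0,)` donates the old train state buffers to save device memory.
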
    # Replicate the train state on each device
    state = state.replicate()

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_epochs}")
    logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
    logger.info(f" Total optimization steps = {total_train_steps}")

    train_time = 0
    epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)
        train_metrics = []

        steps_per_epoch = len(train_dataset) // train_batch_size
        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
        # train
        for batch in train_loader:
            batch = shard(batch)
            state, train_metric = p_train_step(state, batch)
            train_metrics.append(train_metric)

            train_step_progress_bar.update(1)

        train_time += time.time() - train_start

        train_metric = unreplicate(train_metric)

        train_step_progress_bar.close()
        epochs.write(
            f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
            f" {train_metric['learning_rate']})"
        )

        # ======================== Evaluating ==============================
        eval_metrics = []
        eval_steps = len(eval_dataset) // eval_batch_size
        eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
        for batch in eval_loader:
            # Model forward
            metrics = pad_shard_unpad(p_eval_step, static_return=True)(
                state.params, batch, min_device_batch=per_device_eval_batch_size
            )
            eval_metrics.append(metrics)

            eval_step_progress_bar.update(1)

        # normalize eval metrics
        eval_metrics = get_metrics(eval_metrics)
        eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

        # Print metrics and update progress bar
        eval_step_progress_bar.close()
        desc = (
            f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {round(eval_metrics['loss'].item(), 4)} | "
            f"Eval Accuracy: {round(eval_metrics['accuracy'].item(), 4)})"
        )
        epochs.write(desc)
        epochs.desc = desc

        # Save metrics
        if has_tensorboard and jax.process_index() == 0:
            cur_step = epoch * (len(train_dataset) // train_batch_size)
            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)

        # save checkpoint after each epoch and push checkpoint to the hub
        if jax.process_index() == 0:
            params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
            model.save_pretrained(training_args.output_dir, params=params)
            if training_args.push_to_hub:
                repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)


if __name__ == "__main__":
    main()
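
# Example invocation (a minimal sketch; the checkpoint name, paths and hyper-parameters below are
# illustrative, not prescriptive):
#
#   python run_image_classification.py \
#       --output_dir ./vit-base-finetuned \
#       --model_name_or_path google/vit-base-patch16-224-in21k \
#       --train_dir ./data/train \
#       --validation_dir ./data/val \
#       --num_train_epochs 5 \
#       --learning_rate 1e-3 \
#       --per_device_train_batch_size 64 \
#       --per_device_eval_batch_size 64 \
#       --do_train \
#       --do_eval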