[examples/flax] use Repository API for push_to_hub (#13672)

* use Repository for push_to_hub

* update readme

* update other flax scripts

* update readme

* update qa example

* fix push_to_hub call

* fix typo

* fix more typos

* update readme

* use abosolute path to get repo name

* fix glue script
This commit is contained in:
Suraj Patil 2021-09-30 16:38:07 +05:30 committed by GitHub
parent b90096fe14
commit 7db2a79b38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 183 additions and 292 deletions

View File

@ -61,3 +61,14 @@ For a complete overview of models that are supported in JAX/Flax, please have a
Over 3000 pretrained checkpoints are supported in JAX/Flax as of May 2021.
Click [here](https://huggingface.co/models?filter=jax) to see the full list on the 🤗 hub.
## Upload the trained/fine-tuned model to the Hub
All the example scripts support automatic upload of your final model to the [Model Hub](https://huggingface.co/models) by adding a `--push_to_hub` argument. It will then create a repository with your username slash the name of the folder you are using as `output_dir`. For instance, `"sgugger/test-mrpc"` if your username is `sgugger` and you are working in the folder `~/tmp/test-mrpc`.
To specify a given repository name, use the `--hub_model_id` argument. You will need to specify the whole repository name (including your username), for instance `--hub_model_id sgugger/finetuned-bert-mrpc`. To upload to an organization you are a member of, just use the name of that organization instead of your username: `--hub_model_id huggingface/finetuned-bert-mrpc`.
A few notes on this integration:
- you will need to be logged in to the Hugging Face website locally for it to work, the easiest way to achieve this is to run `huggingface-cli login` and then type your username and password when prompted. You can also pass along your authentication token with the `--hub_token` argument.
- the `output_dir` you pick will either need to be a new folder or a local clone of the distant repository you are using.

View File

@ -33,32 +33,10 @@ in Norwegian on a single TPUv3-8 pod.
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"norwegian-roberta-base"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create norwegian-roberta-base
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/norwegian-roberta-base
```
To setup all relevant files for training, let's go into the cloned model directory.
To setup all relevant files for training, let's create a directory.
```bash
cd norwegian-roberta-base
```
Next, let's add a symbolic link to the `run_mlm_flax.py`.
```bash
ln -s ~/transformers/examples/flax/language-modeling/run_mlm_flax.py run_mlm_flax.py
mkdir ./norwegian-roberta-base
```
### Train tokenizer
@ -92,7 +70,7 @@ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=
])
# Save files to disk
tokenizer.save("./tokenizer.json")
tokenizer.save("./norwegian-roberta-base/tokenizer.json")
```
### Create configuration
@ -105,7 +83,7 @@ in the local model folder:
from transformers import RobertaConfig
config = RobertaConfig.from_pretrained("roberta-base", vocab_size=50265)
config.save_pretrained("./")
config.save_pretrained("./norwegian-roberta-base")
```
Great, we have set up our model repository. During training, we will automatically
@ -116,11 +94,11 @@ push the training logs and model weights to the repo.
Next we can run the example script to pretrain the model:
```bash
./run_mlm_flax.py \
--output_dir="./" \
python run_mlm_flax.py \
--output_dir="./norwegian-roberta-base" \
--model_type="roberta" \
--config_name="./" \
--tokenizer_name="./" \
--config_name="./norwegian-roberta-base" \
--tokenizer_name="./norwegian-roberta-base" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--max_seq_length="128" \
@ -157,32 +135,11 @@ in Norwegian on a single TPUv3-8 pod.
The example script uses the 🤗 Datasets library. You can easily customize them to your needs if you need extra processing on your datasets.
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"norwegian-gpt2"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create norwegian-gpt2
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/norwegian-gpt2
```
To setup all relevant files for training, let's go into the cloned model directory.
To setup all relevant files for training, let's create a directory.
```bash
cd norwegian-gpt2
```
Next, let's add a symbolic link to the training script `run_clm_flax.py`.
```bash
ln -s ~/transformers/examples/flax/language-modeling/run_clm_flax.py run_clm_flax.py
mkdir ./norwegian-gpt2
```
### Train tokenizer
@ -216,7 +173,7 @@ tokenizer.train_from_iterator(batch_iterator(), vocab_size=50257, min_frequency=
])
# Save files to disk
tokenizer.save("./tokenizer.json")
tokenizer.save("./norwegian-gpt2/tokenizer.json")
```
### Create configuration
@ -229,7 +186,7 @@ in the local model folder:
from transformers import GPT2Config
config = GPT2Config.from_pretrained("gpt2", resid_pdrop=0.0, embd_pdrop=0.0, attn_pdrop=0.0, vocab_size=50257)
config.save_pretrained("./")
config.save_pretrained("./norwegian-gpt2")
```
Great, we have set up our model repository. During training, we will now automatically
@ -240,11 +197,11 @@ push the training logs and model weights to the repo.
Finally, we can run the example script to pretrain the model:
```bash
./run_clm_flax.py \
--output_dir="./" \
python run_clm_flax.py \
--output_dir="./norwegian-gpt2" \
--model_type="gpt2" \
--config_name="./" \
--tokenizer_name="./" \
--config_name="./norwegian-gpt2" \
--tokenizer_name="./norwegian-gpt2" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--do_train --do_eval \
@ -282,30 +239,10 @@ The example script uses the 🤗 Datasets library. You can easily customize them
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"norwegian-t5-base"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create norwegian-t5-base
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/norwegian-t5-base
```
To setup all relevant files for trairing, let's go into the cloned model directory.
To setup all relevant files for trairing, let's create a directory.
```bash
cd norwegian-t5-base
```
Next, let's add a symbolic link to the `run_t5_mlm_flax.py` and `t5_tokenizer_model` scripts.
```bash
ln -s ~/transformers/examples/flax/language-modeling/run_t5_mlm_flax.py run_t5_mlm_flax.py
ln -s ~/transformers/examples/flax/language-modeling/t5_tokenizer_model.py t5_tokenizer_model.py
cd ./norwegian-t5-base
```
### Train tokenizer
@ -351,7 +288,7 @@ tokenizer.train_from_iterator(
)
# Save files to disk
tokenizer.save("./tokenizer.json")
tokenizer.save("./norwegian-t5-base/tokenizer.json")
```
### Create configuration
@ -364,7 +301,7 @@ in the local model folder:
from transformers import T5Config
config = T5Config.from_pretrained("google/t5-v1_1-base", vocab_size=tokenizer.get_vocab_size())
config.save_pretrained("./")
config.save_pretrained("./norwegian-t5-base")
```
Great, we have set up our model repository. During training, we will automatically
@ -375,11 +312,11 @@ push the training logs and model weights to the repo.
Next we can run the example script to pretrain the model:
```bash
./run_t5_mlm_flax.py \
--output_dir="./" \
python run_t5_mlm_flax.py \
--output_dir="./norwegian-t5-base" \
--model_type="t5" \
--config_name="./" \
--tokenizer_name="./" \
--config_name="./norwegian-t5-base" \
--tokenizer_name="./norwegian-t5-base" \
--dataset_name="oscar" \
--dataset_config_name="unshuffled_deduplicated_no" \
--max_seq_length="512" \

View File

@ -43,6 +43,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
@ -54,6 +55,7 @@ from transformers import (
is_tensorboard_available,
set_seed,
)
from transformers.file_utils import get_full_repo_name
from transformers.testing_utils import CaptureLogger
@ -275,6 +277,16 @@ def main():
# Set seed before initializing model.
set_seed(training_args.seed)
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
@ -654,12 +666,10 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(unreplicate(state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of step {cur_step}",
)
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
if __name__ == "__main__":

View File

@ -41,6 +41,7 @@ import optax
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@ -54,6 +55,7 @@ from transformers import (
is_tensorboard_available,
set_seed,
)
from transformers.file_utils import get_full_repo_name
MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
@ -308,6 +310,16 @@ if __name__ == "__main__":
# Set seed before initializing model.
set_seed(training_args.seed)
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
@ -683,9 +695,7 @@ if __name__ == "__main__":
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of step {cur_step}",
)
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

View File

@ -39,6 +39,7 @@ import optax
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
@ -52,6 +53,7 @@ from transformers import (
is_tensorboard_available,
set_seed,
)
from transformers.file_utils import get_full_repo_name
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
@ -438,6 +440,16 @@ if __name__ == "__main__":
# Set seed before initializing model.
set_seed(training_args.seed)
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
@ -791,9 +803,7 @@ if __name__ == "__main__":
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of step {cur_step}",
)
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)

View File

@ -26,31 +26,6 @@ of the script.
The following example fine-tunes BERT on SQuAD:
To begin with it is recommended to create a model repository to save the trained model and logs.
Here we call the model `"bert-qa-squad-test"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create bert-qa-squad-test
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/bert-qa-squad-test
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_qa.py`.
```bash
export MODEL_DIR="./bert-qa-squad-test"
ln -s ~/transformers/examples/flax/question-answering/run_qa.py run_qa.py
```
```bash
python run_qa.py \
@ -63,7 +38,7 @@ python run_qa.py \
--learning_rate 3e-5 \
--num_train_epochs 2 \
--per_device_train_batch_size 12 \
--output_dir ${MODEL_DIR} \
--output_dir ./bert-qa-squad \
--eval_steps 1000 \
--push_to_hub
```
@ -101,8 +76,9 @@ python run_qa.py \
--num_train_epochs 2 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir /tmp/wwm_uncased_finetuned_squad/ \
--eval_steps 1000
--output_dir ./wwm_uncased_finetuned_squad/ \
--eval_steps 1000 \
--push_to_hub
```
Training with the previously defined hyper-parameters yields the following results:

View File

@ -25,6 +25,7 @@ import sys
import time
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
import datasets
@ -41,6 +42,7 @@ from flax.jax_utils import replicate, unreplicate
from flax.metrics import tensorboard
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from transformers import (
AutoConfig,
AutoTokenizer,
@ -50,6 +52,7 @@ from transformers import (
PreTrainedTokenizerFast,
TrainingArguments,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils import check_min_version
from utils_qa import postprocess_qa_predictions
@ -359,6 +362,16 @@ def main():
transformers.utils.logging.set_verbosity_error()
# endregion
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
# region Load Data
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
@ -891,12 +904,10 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(unreplicate(state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of step {cur_step}",
)
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"
# endregion

View File

@ -11,43 +11,12 @@ way which enables simple and efficient model parallelism.
For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files and you also will find examples of these below.
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"bart-base-xsum"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create bart-base-xsum
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/bart-base-xsum
```
To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
```
cd bart-base-xsum
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_summarization_flax.py`.
```bash
export MODEL_DIR="./bart-base-xsum"
ln -s ~/transformers/examples/flax/summarization/run_summarization_flax.py run_summarization_flax.py
```
### Train the model
Next we can run the example script to train the model:
```bash
python run_summarization_flax.py \
--output_dir ${MODEL_DIR} \
--output_dir ./bart-base-xsum \
--model_name_or_path facebook/bart-base \
--tokenizer_name facebook/bart-base \
--dataset_name="xsum" \

View File

@ -42,6 +42,7 @@ from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
@ -52,7 +53,7 @@ from transformers import (
TrainingArguments,
is_tensorboard_available,
)
from transformers.file_utils import is_offline_mode
from transformers.file_utils import get_full_repo_name, is_offline_mode
logger = logging.getLogger(__name__)
@ -333,6 +334,16 @@ def main():
# Set the verbosity to info of the Transformers logger (on main process only):
logger.info(f"Training/evaluation parameters {training_args}")
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
@ -800,12 +811,10 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of epoch {epoch+1}",
)
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
if __name__ == "__main__":

View File

@ -21,47 +21,15 @@ limitations under the License.
Based on the script [`run_flax_glue.py`](https://github.com/huggingface/transformers/blob/master/examples/flax/text-classification/run_flax_glue.py).
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models).
To begin with it is recommended to create a model repository to save the trained model and logs.
Here we call the model `"bert-glue-mrpc-test"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create bert-glue-mrpc-test
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/bert-glue-mrpc-test
```
To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
```
cd bert-glue-mrpc-test
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_flax_glue.py`.
```bash
export TASK_NAME=mrpc
export MODEL_DIR="./bert-glue-mrpc-test"
ln -s ~/transformers/examples/flax/text-classification/run_flax_glue.py run_flax_glue.py
```
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models) and can also be used for a
dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a csv or a JSON file (the script might need some tweaks in that case,
refer to the comments inside for help).
GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
```bash
export TASK_NAME=mrpc
python run_flax_glue.py \
--model_name_or_path bert-base-cased \
--task_name ${TASK_NAME} \
@ -69,7 +37,7 @@ python run_flax_glue.py \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--output_dir ${MODEL_DIR} \
--output_dir ./$TASK_NAME/ \
--push_to_hub
```

View File

@ -20,6 +20,7 @@ import os
import random
import time
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Tuple
import datasets
@ -34,7 +35,9 @@ from flax.jax_utils import replicate, unreplicate
from flax.metrics import tensorboard
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from transformers import AutoConfig, AutoTokenizer, FlaxAutoModelForSequenceClassification, PretrainedConfig
from transformers.file_utils import get_full_repo_name
logger = logging.getLogger(__name__)
@ -128,6 +131,10 @@ def parse_args():
action="store_true",
help="If passed, model checkpoints and tensorboard logs will be pushed to the hub",
)
parser.add_argument(
"--hub_model_id", type=str, help="The name of the repository to keep in sync with the local `output_dir`."
)
parser.add_argument("--hub_token", type=str, help="The token to use to push to the Model Hub.")
args = parser.parse_args()
# Sanity checks
@ -141,6 +148,9 @@ def parse_args():
extension = args.validation_file.split(".")[-1]
assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
if args.push_to_hub:
assert args.output_dir is not None, "Need an `output_dir` to create a repo when `--push_to_hub` is passed."
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
@ -267,6 +277,14 @@ def main():
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# Handle the repository creation
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).absolute().name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo = Repository(args.output_dir, clone_from=repo_name)
# Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
# or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
@ -499,12 +517,10 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
args.output_dir,
params=params,
push_to_hub=args.push_to_hub,
commit_message=f"Saving weights and logs of epoch {epoch}",
)
model.save_pretrained(args.output_dir, params=params)
tokenizer.save_pretrained(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
if __name__ == "__main__":

View File

@ -22,31 +22,6 @@ It will either run on a datasets hosted on our hub or with your own text files f
The following example fine-tunes BERT on CoNLL-2003:
To begin with it is recommended to create a model repository to save the trained model and logs.
Here we call the model `"bert-ner-conll2003-test"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create bert-ner-conll2003-test
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/bert-ner-conll2003-test
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_flax_ner.py`.
```bash
export MODEL_DIR="./bert-ner-conll2003-test"
ln -s ~/transformers/examples/flax/token-classification/run_flax_ner.py run_flax_ner.py
```
```bash
python run_flax_ner.py \
@ -56,7 +31,7 @@ python run_flax_ner.py \
--learning_rate 2e-5 \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--output_dir ${MODEL_DIR} \
--output_dir ./bert-ner-conll2003 \
--eval_steps 300 \
--push_to_hub
```

View File

@ -21,6 +21,7 @@ import sys
import time
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
import datasets
@ -37,6 +38,7 @@ from flax.jax_utils import replicate, unreplicate
from flax.metrics import tensorboard
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository
from transformers import (
AutoConfig,
AutoTokenizer,
@ -44,6 +46,7 @@ from transformers import (
HfArgumentParser,
TrainingArguments,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
@ -304,6 +307,16 @@ def main():
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
# or just provide the name of one of the public datasets for token classification task available on the hub at https://huggingface.co/datasets/
# (the dataset will be downloaded automatically from the datasets Hub).
@ -656,12 +669,10 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(unreplicate(state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of step {cur_step}",
)
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of step {cur_step}", blocking=False)
epochs.desc = f"Epoch ... {epoch + 1}/{num_epochs}"

View File

@ -25,37 +25,6 @@ way which enables simple and efficient model parallelism.
In this example we will train/fine-tune the model on the [imagenette](https://github.com/fastai/imagenette) dataset.
Let's start by creating a model repository to save the trained model and logs.
Here we call the model `"vit-base-patch16-imagenette"`, but you can change the model name as you like.
You can do this either directly on [huggingface.co](https://huggingface.co/new) (assuming that
you are logged in) or via the command line:
```
huggingface-cli repo create vit-base-patch16-imagenette
```
Next we clone the model repository to add the tokenizer and model files.
```
git clone https://huggingface.co/<your-username>/vit-base-patch16-imagenette
```
To ensure that all tensorboard traces will be uploaded correctly, we need to
track them. You can run the following command inside your model repo to do so.
```
cd vit-base-patch16-imagenette
git lfs track "*tfevents*"
```
Great, we have set up our model repository. During training, we will automatically
push the training logs and model weights to the repo.
Next, let's add a symbolic link to the `run_image_classification_flax.py`.
```bash
export MODEL_DIR="./vit-base-patch16-imagenette
ln -s ~/transformers/examples/flax/summarization/run_image_classification_flax.py run_image_classification_flax.py
```
## Prepare the dataset
We will use the [imagenette](https://github.com/fastai/imagenette) dataset to train/fine-tune our model. Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).
@ -86,7 +55,7 @@ Next we can run the example script to fine-tune the model:
```bash
python run_image_classification.py \
--output_dir ${MODEL_DIR} \
--output_dir ./vit-base-patch16-imagenette \
--model_name_or_path google/vit-base-patch16-224-in21k \
--train_dir="imagenette2/train" \
--validation_dir="imagenette2/val" \

View File

@ -42,6 +42,7 @@ from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
@ -52,6 +53,7 @@ from transformers import (
is_tensorboard_available,
set_seed,
)
from transformers.file_utils import get_full_repo_name
logger = logging.getLogger(__name__)
@ -205,6 +207,16 @@ def main():
# set seed for random transforms and torch dataloaders
set_seed(training_args.seed)
# Handle the repository creation
if training_args.push_to_hub:
if training_args.hub_model_id is None:
repo_name = get_full_repo_name(
Path(training_args.output_dir).absolute().name, token=training_args.hub_token
)
else:
repo_name = training_args.hub_model_id
repo = Repository(training_args.output_dir, clone_from=repo_name)
# Initialize datasets and pre-processing transforms
# We use torchvision here for faster pre-processing
# Note that here we are using some default pre-processing, for maximum accuray
@ -455,12 +467,9 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
model.save_pretrained(
training_args.output_dir,
params=params,
push_to_hub=training_args.push_to_hub,
commit_message=f"Saving weights and logs of epoch {epoch+1}",
)
model.save_pretrained(training_args.output_dir, params=params)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
if __name__ == "__main__":