Add seed setting to image classification example (#18519)

This commit is contained in:
regisss 2022-08-08 14:08:11 +02:00 committed by GitHub
parent 9129fd0377
commit 88a0ce57bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 0 deletions

View File

@ -43,6 +43,7 @@ from transformers import (
HfArgumentParser, HfArgumentParser,
Trainer, Trainer,
TrainingArguments, TrainingArguments,
set_seed,
) )
from transformers.trainer_utils import get_last_checkpoint from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry from transformers.utils import check_min_version, send_example_telemetry
@ -214,6 +215,9 @@ def main():
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch." "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
) )
# Set seed before initializing model.
set_seed(training_args.seed)
# Initialize our dataset and prepare it for the 'image-classification' task. # Initialize our dataset and prepare it for the 'image-classification' task.
if data_args.dataset_name is not None: if data_args.dataset_name is not None:
dataset = load_dataset( dataset = load_dataset(