Add seed setting to image classification example (#18519)
This commit is contained in:
parent
9129fd0377
commit
88a0ce57bb
|
@ -43,6 +43,7 @@ from transformers import (
|
||||||
HfArgumentParser,
|
HfArgumentParser,
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
|
set_seed,
|
||||||
)
|
)
|
||||||
from transformers.trainer_utils import get_last_checkpoint
|
from transformers.trainer_utils import get_last_checkpoint
|
||||||
from transformers.utils import check_min_version, send_example_telemetry
|
from transformers.utils import check_min_version, send_example_telemetry
|
||||||
|
@ -214,6 +215,9 @@ def main():
|
||||||
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Set seed before initializing model.
|
||||||
|
set_seed(training_args.seed)
|
||||||
|
|
||||||
# Initialize our dataset and prepare it for the 'image-classification' task.
|
# Initialize our dataset and prepare it for the 'image-classification' task.
|
||||||
if data_args.dataset_name is not None:
|
if data_args.dataset_name is not None:
|
||||||
dataset = load_dataset(
|
dataset = load_dataset(
|
||||||
|
|
Loading…
Reference in New Issue