Small QOL improvements to TrainingArguments (#7475)

* Small QOL improvements to TrainingArguments

* With the self.
This commit is contained in:
Sylvain Gugger 2020-09-30 12:12:03 -04:00 committed by GitHub
parent dc7d2daa4c
commit a97a73e0ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 4 deletions

View File

@ -49,8 +49,9 @@ class TrainingArguments:
:obj:`output_dir` points to a checkpoint directory.
do_train (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run training or not.
do_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run evaluation on the dev set or not.
do_eval (:obj:`bool`, `optional`):
Whether to run evaluation on the dev set or not. Will default to :obj:`evaluation_strategy` different from
:obj:`"no"`.
do_predict (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to run predictions on the test set or not.
evaluation_strategy(:obj:`str` or :class:`~transformers.trainer_utils.EvaluationStrategy`, `optional`, defaults to :obj:`"no"`):
@ -183,7 +184,7 @@ class TrainingArguments:
)
do_train: bool = field(default=False, metadata={"help": "Whether to run training."})
do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."})
do_eval: bool = field(default=None, metadata={"help": "Whether to run eval on the dev set."})
do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."})
evaluate_during_training: bool = field(
default=None,
@ -333,7 +334,8 @@ class TrainingArguments:
)
else:
self.evaluation_strategy = EvaluationStrategy(self.evaluation_strategy)
if self.do_eval is None:
self.do_eval = self.evaluation_strategy != EvaluationStrategy.NO
if self.eval_steps is None:
self.eval_steps = self.logging_steps
@ -341,6 +343,8 @@ class TrainingArguments:
self.metric_for_best_model = "loss"
if self.greater_is_better is None and self.metric_for_best_model is not None:
self.greater_is_better = self.metric_for_best_model not in ["loss", "eval_loss"]
if self.run_name is None:
self.run_name = self.output_dir
@property
def train_batch_size(self) -> int: