Improve pytorch examples for fp16 (#9796)
* Pad to 8x for fp16 multiple choice example (#9752) * Pad to 8x for fp16 squad trainer example (#9752) * Pad to 8x for fp16 ner example (#9752) * Pad to 8x for fp16 swag example (#9752) * Pad to 8x for fp16 qa beam search example (#9752) * Pad to 8x for fp16 qa example (#9752) * Pad to 8x for fp16 seq2seq example (#9752) * Pad to 8x for fp16 glue example (#9752) * Pad to 8x for fp16 new ner example (#9752) * update script template #9752 * Update examples/multiple-choice/run_swag.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/question-answering/run_qa.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update examples/question-answering/run_qa_beam_search.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * improve code quality #9752 Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
781e4b1384
commit
10e5f28212
|
@ -28,6 +28,7 @@ from transformers import (
|
|||
AutoConfig,
|
||||
AutoModelForMultipleChoice,
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
|
@ -188,6 +189,9 @@ def main():
|
|||
preds = np.argmax(p.predictions, axis=1)
|
||||
return {"acc": simple_accuracy(preds, p.label_ids)}
|
||||
|
||||
# Data collator
|
||||
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) if training_args.fp16 else None
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
|
@ -195,6 +199,7 @@ def main():
|
|||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
# Training
|
||||
|
|
|
@ -23,7 +23,14 @@ from dataclasses import dataclass, field
|
|||
from typing import Optional
|
||||
|
||||
import transformers
|
||||
from transformers import AutoConfig, AutoModelForQuestionAnswering, AutoTokenizer, HfArgumentParser, SquadDataset
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForQuestionAnswering,
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
HfArgumentParser,
|
||||
SquadDataset,
|
||||
)
|
||||
from transformers import SquadDataTrainingArguments as DataTrainingArguments
|
||||
from transformers import Trainer, TrainingArguments
|
||||
from transformers.trainer_utils import is_main_process
|
||||
|
@ -145,12 +152,16 @@ def main():
|
|||
else None
|
||||
)
|
||||
|
||||
# Data collator
|
||||
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) if training_args.fp16 else None
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
# Training
|
||||
|
|
|
@ -30,6 +30,7 @@ from transformers import (
|
|||
AutoConfig,
|
||||
AutoModelForTokenClassification,
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
|
@ -237,6 +238,9 @@ def main():
|
|||
"f1": f1_score(out_label_list, preds_list),
|
||||
}
|
||||
|
||||
# Data collator
|
||||
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) if training_args.fp16 else None
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
|
@ -244,6 +248,7 @@ def main():
|
|||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
# Training
|
||||
|
|
|
@ -316,7 +316,9 @@ def main():
|
|||
|
||||
# Data collator
|
||||
data_collator = (
|
||||
default_data_collator if data_args.pad_to_max_length else DataCollatorForMultipleChoice(tokenizer=tokenizer)
|
||||
default_data_collator
|
||||
if data_args.pad_to_max_length
|
||||
else DataCollatorForMultipleChoice(tokenizer=tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
|
||||
)
|
||||
|
||||
# Metric
|
||||
|
|
|
@ -411,7 +411,11 @@ def main():
|
|||
# Data collator
|
||||
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
||||
# collator.
|
||||
data_collator = default_data_collator if data_args.pad_to_max_length else DataCollatorWithPadding(tokenizer)
|
||||
data_collator = (
|
||||
default_data_collator
|
||||
if data_args.pad_to_max_length
|
||||
else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
|
||||
)
|
||||
|
||||
# Post-processing:
|
||||
def post_processing_function(examples, features, predictions):
|
||||
|
|
|
@ -448,7 +448,11 @@ def main():
|
|||
# Data collator
|
||||
# We have already padded to max length if the corresponding flag is True, otherwise we need to pad in the data
|
||||
# collator.
|
||||
data_collator = default_data_collator if data_args.pad_to_max_length else DataCollatorWithPadding(tokenizer)
|
||||
data_collator = (
|
||||
default_data_collator
|
||||
if data_args.pad_to_max_length
|
||||
else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
|
||||
)
|
||||
|
||||
# Post-processing:
|
||||
def post_processing_function(examples, features, predictions):
|
||||
|
|
|
@ -437,7 +437,11 @@ def main():
|
|||
if data_args.pad_to_max_length:
|
||||
data_collator = default_data_collator
|
||||
else:
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=label_pad_token_id)
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
label_pad_token_id=label_pad_token_id,
|
||||
pad_to_multiple_of=8 if training_args.fp16 else None,
|
||||
)
|
||||
|
||||
# Metric
|
||||
metric_name = "rouge" if data_args.task.startswith("summarization") else "sacrebleu"
|
||||
|
|
|
@ -30,6 +30,7 @@ from transformers import (
|
|||
AutoConfig,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
EvalPrediction,
|
||||
HfArgumentParser,
|
||||
PretrainedConfig,
|
||||
|
@ -375,6 +376,14 @@ def main():
|
|||
else:
|
||||
return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
|
||||
|
||||
# Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
|
||||
if data_args.pad_to_max_length:
|
||||
data_collator = default_data_collator
|
||||
elif training_args.fp16:
|
||||
data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||
else:
|
||||
data_collator = None
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
|
@ -383,8 +392,7 @@ def main():
|
|||
eval_dataset=eval_dataset if training_args.do_eval else None,
|
||||
compute_metrics=compute_metrics,
|
||||
tokenizer=tokenizer,
|
||||
# Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
|
||||
data_collator=default_data_collator if data_args.pad_to_max_length else None,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
# Training
|
||||
|
|
|
@ -327,7 +327,7 @@ def main():
|
|||
)
|
||||
|
||||
# Data collator
|
||||
data_collator = DataCollatorForTokenClassification(tokenizer)
|
||||
data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8 if training_args.fp16 else None)
|
||||
|
||||
# Metrics
|
||||
metric = load_metric("seqeval")
|
||||
|
|
|
@ -33,6 +33,7 @@ from transformers import (
|
|||
AutoConfig,
|
||||
{{cookiecutter.model_class}},
|
||||
AutoTokenizer,
|
||||
DataCollatorWithPadding,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
|
@ -323,7 +324,7 @@ def main():
|
|||
)
|
||||
|
||||
# Data collator
|
||||
data_collator=default_data_collator
|
||||
data_collator=default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Trainer(
|
||||
|
|
Loading…
Reference in New Issue