Merge pull request #2046 from jplu/tf2-ner-example
Add NER TF2 example.
commit 5482822a2b
@ -5,6 +5,7 @@ The ``.optimization`` module provides:

- an optimizer with weight decay fixed that can be used to fine-tune models,
- several schedules in the form of schedule objects that inherit from ``_LRSchedule``, and
- a gradient accumulation class to accumulate the gradients of multiple batches

``AdamW``
~~~~~~~~~~~~~~~~

@ -12,6 +13,15 @@ The ``.optimization`` module provides:

.. autoclass:: transformers.AdamW
    :members:

``AdamWeightDecay``
~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.AdamWeightDecay
    :members:

.. autofunction:: transformers.create_optimizer

Schedules
----------------------------------------------------

@ -49,3 +59,17 @@ Learning Rate Schedules

.. image:: /imgs/warmup_linear_schedule.png
    :target: /imgs/warmup_linear_schedule.png
    :alt:

``WarmUp``
~~~~~~~~~~~~~~~~

.. autoclass:: transformers.WarmUp
    :members:

Gradient Strategies
----------------------------------------------------

``GradientAccumulator``
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: transformers.GradientAccumulator
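For orientation (not part of the diff): a minimal sketch of how the pieces documented above fit together, assuming only the API this PR introduces — `create_optimizer` returns an `AdamWeightDecay` whose learning rate follows a `WarmUp`-wrapped linear decay, and `GradientAccumulator` collects gradients until you are ready to apply them:

```python
import tensorflow as tf
from transformers import create_optimizer, GradientAccumulator

# AdamWeightDecay with linear warmup over the first 100 steps,
# then linear decay to 0 at step 1000 (see optimization_tf.py below).
optimizer = create_optimizer(init_lr=5e-5, num_train_steps=1000, num_warmup_steps=100)

accumulator = GradientAccumulator()
variable = tf.Variable([1.0, 2.0])

for _ in range(4):  # four micro-batches
    with tf.GradientTape() as tape:
        loss = tf.reduce_sum(variable ** 2)  # placeholder loss
    accumulator(tape.gradient(loss, [variable]))

# This AdamWeightDecay.apply_gradients takes (grads_and_vars, clip_norm)
# and clips by global norm before the Adam update; the training script
# additionally divides accumulated gradients by the accumulation step count.
optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])), 1.0)
accumulator.reset()
```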
@ -467,7 +467,8 @@ Training with the previously defined hyper-parameters yields the following resul

## Named Entity Recognition

Based on the script [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py).
Based on the scripts [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py) for Pytorch and
[`run_tf_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_tf_ner.py) for Tensorflow 2.
This example fine-tunes BERT Multilingual on GermEval 2014 (German NER).
Details and results for the fine-tuning are provided by @stefan-it.

@ -512,7 +513,7 @@ The GermEval 2014 dataset has many more labels than CoNLL-2002/2003 datasets, so
cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$" | sort | uniq > labels.txt
```

### Training
### Prepare the run

Additional environment variables must be set:

@ -524,6 +525,8 @@ export SAVE_STEPS=750
export SEED=1
```

### Run the Pytorch version

To start training, just run:

```bash

@ -544,7 +547,7 @@ python3 run_ner.py --data_dir ./ \

If your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be evaluated on both the development and test datasets.

### Evaluation
#### Evaluation

Evaluation on the development dataset outputs the following for our example:

@ -566,7 +569,7 @@ On the test dataset the following results could be achieved:
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
```

### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)
#### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)

Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run):

@ -576,6 +579,72 @@ Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) a
| `roberta-large` | 95.96 | 91.87
| `distilbert-base-uncased` | 94.34 | 90.32

### Run the Tensorflow 2 version

To start training, just run:

```bash
python3 run_tf_ner.py --data_dir ./ \
--model_type bert \
--labels ./labels.txt \
--model_name_or_path $BERT_MODEL \
--output_dir $OUTPUT_DIR \
--max_seq_length $MAX_LENGTH \
--num_train_epochs $NUM_EPOCHS \
--per_device_train_batch_size $BATCH_SIZE \
--save_steps $SAVE_STEPS \
--seed $SEED \
--do_train \
--do_eval \
--do_predict
```

As with the Pytorch version, if your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be evaluated on both the development and test datasets.

#### Evaluation

Evaluation on the development dataset outputs the following for our example:
```bash
           precision    recall  f1-score   support

 LOCderiv     0.7619    0.6154    0.6809        52
  PERpart     0.8724    0.8997    0.8858      4057
  OTHpart     0.9360    0.9466    0.9413       711
  ORGpart     0.7015    0.6989    0.7002       269
  LOCpart     0.7668    0.8488    0.8057       496
      LOC     0.8745    0.9191    0.8963       235
 ORGderiv     0.7723    0.8571    0.8125        91
 OTHderiv     0.4800    0.6667    0.5581        18
      OTH     0.5789    0.6875    0.6286        16
 PERderiv     0.5385    0.3889    0.4516        18
      PER     0.5000    0.5000    0.5000         2
      ORG     0.0000    0.0000    0.0000         3

micro avg     0.8574    0.8862    0.8715      5968
macro avg     0.8575    0.8862    0.8713      5968
```

On the test dataset the following results could be achieved:
```bash
           precision    recall  f1-score   support

  PERpart     0.8847    0.8944    0.8896      9397
  OTHpart     0.9376    0.9353    0.9365      1639
  ORGpart     0.7307    0.7044    0.7173       697
      LOC     0.9133    0.9394    0.9262       561
  LOCpart     0.8058    0.8157    0.8107      1150
      ORG     0.0000    0.0000    0.0000         8
 OTHderiv     0.5882    0.4762    0.5263        42
 PERderiv     0.6571    0.5227    0.5823        44
      OTH     0.4906    0.6667    0.5652        39
 ORGderiv     0.7016    0.7791    0.7383       172
 LOCderiv     0.8256    0.6514    0.7282       109
      PER     0.0000    0.0000    0.0000        11

micro avg     0.8722    0.8774    0.8748     13869
macro avg     0.8712    0.8774    0.8740     13869
```
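The reports above come from `seqeval`, the same library `run_tf_ner.py` imports; it scores complete entities rather than single tokens. A toy sketch of producing such a report (made-up tags, not GermEval output):

```python
from seqeval import metrics

# Gold and predicted tag sequences for two toy sentences (IOB-style labels).
y_true = [["B-LOC", "I-LOC", "O"], ["B-PER", "O", "O"]]
y_pred = [["B-LOC", "I-LOC", "O"], ["B-LOC", "O", "O"]]

print(metrics.classification_report(y_true, y_pred, digits=4))
print("precision =", metrics.precision_score(y_true, y_pred))
print("recall =", metrics.recall_score(y_true, y_pred))
print("f1 =", metrics.f1_score(y_true, y_pred))
```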
## Abstractive summarization

Based on the script
@ -0,0 +1,615 @@
# coding=utf-8
import datetime
import os
import math
import glob
import re
import tensorflow as tf
import collections
import numpy as np
from seqeval import metrics
import _pickle as pickle
from absl import logging
from transformers import TF2_WEIGHTS_NAME, BertConfig, BertTokenizer, TFBertForTokenClassification
from transformers import RobertaConfig, RobertaTokenizer, TFRobertaForTokenClassification
from transformers import DistilBertConfig, DistilBertTokenizer, TFDistilBertForTokenClassification
from transformers import create_optimizer, GradientAccumulator
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
from fastprogress import master_bar, progress_bar
from absl import flags
from absl import app


ALL_MODELS = sum(
    (tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
    ())

MODEL_CLASSES = {
    "bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
    "roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
    "distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer)
}


flags.DEFINE_string(
    "data_dir", None,
    "The input data dir. Should contain the .conll files (or other data files) "
    "for the task.")

flags.DEFINE_string(
    "model_type", None,
    "Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))

flags.DEFINE_string(
    "model_name_or_path", None,
    "Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))

flags.DEFINE_string(
    "output_dir", None,
    "The output directory where the model checkpoints will be written.")

flags.DEFINE_string(
    "labels", "",
    "Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.")

flags.DEFINE_string(
    "config_name", "",
    "Pretrained config name or path if not the same as model_name")

flags.DEFINE_string(
    "tokenizer_name", "",
    "Pretrained tokenizer name or path if not the same as model_name")

flags.DEFINE_string(
    "cache_dir", "",
    "Where do you want to store the pre-trained models downloaded from s3")

flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sentence length after tokenization. "
    "Sequences longer than this will be truncated, sequences shorter "
    "will be padded.")

flags.DEFINE_string(
    "tpu", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")

flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Total number of TPU cores to use.")

flags.DEFINE_boolean(
    "do_train", False,
    "Whether to run training.")

flags.DEFINE_boolean(
    "do_eval", False,
    "Whether to run eval on the dev set.")

flags.DEFINE_boolean(
    "do_predict", False,
    "Whether to run predictions on the test set.")

flags.DEFINE_boolean(
    "evaluate_during_training", False,
    "Whether to run evaluation during training at each logging step.")

flags.DEFINE_boolean(
    "do_lower_case", False,
    "Set this flag if you are using an uncased model.")

flags.DEFINE_integer(
    "per_device_train_batch_size", 8,
    "Batch size per GPU/CPU/TPU for training.")

flags.DEFINE_integer(
    "per_device_eval_batch_size", 8,
    "Batch size per GPU/CPU/TPU for evaluation.")

flags.DEFINE_integer(
    "gradient_accumulation_steps", 1,
    "Number of updates steps to accumulate before performing a backward/update pass.")

flags.DEFINE_float(
    "learning_rate", 5e-5,
    "The initial learning rate for Adam.")

flags.DEFINE_float(
    "weight_decay", 0.0,
    "Weight decay if we apply some.")

flags.DEFINE_float(
    "adam_epsilon", 1e-8,
    "Epsilon for Adam optimizer.")

flags.DEFINE_float(
    "max_grad_norm", 1.0,
    "Max gradient norm.")

flags.DEFINE_integer(
    "num_train_epochs", 3,
    "Total number of training epochs to perform.")

flags.DEFINE_integer(
    "max_steps", -1,
    "If > 0: set total number of training steps to perform. Override num_train_epochs.")

flags.DEFINE_integer(
    "warmup_steps", 0,
    "Linear warmup over warmup_steps.")

flags.DEFINE_integer(
    "logging_steps", 50,
    "Log every X updates steps.")

flags.DEFINE_integer(
    "save_steps", 50,
    "Save checkpoint every X updates steps.")

flags.DEFINE_boolean(
    "eval_all_checkpoints", False,
    "Evaluate all checkpoints starting with the same prefix as model_name and ending with step number")

flags.DEFINE_boolean(
    "no_cuda", False,
    "Avoid using CUDA when available")

flags.DEFINE_boolean(
    "overwrite_output_dir", False,
    "Overwrite the content of the output directory")

flags.DEFINE_boolean(
    "overwrite_cache", False,
    "Overwrite the cached training and evaluation sets")

flags.DEFINE_integer(
    "seed", 42,
    "random seed for initialization")

flags.DEFINE_boolean(
    "fp16", False,
    "Whether to use 16-bit (mixed) precision instead of 32-bit")

flags.DEFINE_string(
    "gpus", "0",
    "Comma separated list of gpus devices. If only one, switch to single "
    "gpu strategy, if None takes all the gpus available.")


def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id):
    if args['max_steps'] > 0:
        num_train_steps = args['max_steps'] * args['gradient_accumulation_steps']
        args['num_train_epochs'] = 1
    else:
        num_train_steps = math.ceil(num_train_examples / train_batch_size) // args['gradient_accumulation_steps'] * args['num_train_epochs']

    writer = tf.summary.create_file_writer("/tmp/mylogs")

    with strategy.scope():
        loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
        optimizer = create_optimizer(args['learning_rate'], num_train_steps, args['warmup_steps'])

        if args['fp16']:
            optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, 'dynamic')

        loss_metric = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
        gradient_accumulator = GradientAccumulator()

    logging.info("***** Running training *****")
    logging.info("  Num examples = %d", num_train_examples)
    logging.info("  Num Epochs = %d", args['num_train_epochs'])
    logging.info("  Instantaneous batch size per device = %d", args['per_device_train_batch_size'])
    logging.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                 train_batch_size * args['gradient_accumulation_steps'])
    logging.info("  Gradient Accumulation steps = %d", args['gradient_accumulation_steps'])
    logging.info("  Total training steps = %d", num_train_steps)

    model.summary()

    @tf.function
    def apply_gradients():
        grads_and_vars = []

        for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
            if gradient is not None:
                scaled_gradient = gradient / (args['n_device'] * args['gradient_accumulation_steps'])
                grads_and_vars.append((scaled_gradient, variable))
            else:
                grads_and_vars.append((gradient, variable))

        optimizer.apply_gradients(grads_and_vars, args['max_grad_norm'])
        gradient_accumulator.reset()

    @tf.function
    def train_step(train_features, train_labels):
        def step_fn(train_features, train_labels):
            inputs = {'attention_mask': train_features['input_mask'], 'training': True}

            if args['model_type'] != "distilbert":
                inputs["token_type_ids"] = train_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None

            with tf.GradientTape() as tape:
                logits = model(train_features['input_ids'], **inputs)[0]
                logits = tf.reshape(logits, (-1, len(labels) + 1))
                active_loss = tf.reshape(train_features['input_mask'], (-1,))
                active_logits = tf.boolean_mask(logits, active_loss)
                train_labels = tf.reshape(train_labels, (-1,))
                active_labels = tf.boolean_mask(train_labels, active_loss)
                cross_entropy = loss_fct(active_labels, active_logits)
                loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
                grads = tape.gradient(loss, model.trainable_variables)

                gradient_accumulator(grads)

            return cross_entropy

        per_example_losses = strategy.experimental_run_v2(step_fn, args=(train_features, train_labels))
        mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)

        return mean_loss

    current_time = datetime.datetime.now()
    train_iterator = master_bar(range(args['num_train_epochs']))
    global_step = 0
    logging_loss = 0.0

    for epoch in train_iterator:
        epoch_iterator = progress_bar(train_dataset, total=num_train_steps, parent=train_iterator, display=args['n_device'] > 1)
        step = 1

        with strategy.scope():
            for train_features, train_labels in epoch_iterator:
                loss = train_step(train_features, train_labels)

                if step % args['gradient_accumulation_steps'] == 0:
                    strategy.experimental_run_v2(apply_gradients)

                    loss_metric(loss)

                    global_step += 1

                    if args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0:
                        # Log metrics
                        if args['n_device'] == 1 and args['evaluate_during_training']:  # Only evaluate when single GPU otherwise metrics may not average well
                            y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
                            report = metrics.classification_report(y_true, y_pred, digits=4)

                            logging.info("Eval at step " + str(global_step) + "\n" + report)
                            logging.info("eval_loss: " + str(eval_loss))

                            precision = metrics.precision_score(y_true, y_pred)
                            recall = metrics.recall_score(y_true, y_pred)
                            f1 = metrics.f1_score(y_true, y_pred)

                            with writer.as_default():
                                tf.summary.scalar("eval_loss", eval_loss, global_step)
                                tf.summary.scalar("precision", precision, global_step)
                                tf.summary.scalar("recall", recall, global_step)
                                tf.summary.scalar("f1", f1, global_step)

                        lr = optimizer.learning_rate
                        learning_rate = lr(step)

                        with writer.as_default():
                            tf.summary.scalar("lr", learning_rate, global_step)
                            tf.summary.scalar("loss", (loss_metric.result() - logging_loss) / args['logging_steps'], global_step)

                        logging_loss = loss_metric.result()

                    with writer.as_default():
                        tf.summary.scalar("loss", loss_metric.result(), step=step)

                    if args['save_steps'] > 0 and global_step % args['save_steps'] == 0:
                        # Save model checkpoint
                        output_dir = os.path.join(args['output_dir'], "checkpoint-{}".format(global_step))

                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)

                        model.save_pretrained(output_dir)
                        logging.info("Saving model checkpoint to %s", output_dir)

                train_iterator.child.comment = f'loss : {loss_metric.result()}'
                step += 1

        train_iterator.write(f'loss epoch {epoch + 1}: {loss_metric.result()}')

        loss_metric.reset_states()

    logging.info("  Training took time = {}".format(datetime.datetime.now() - current_time))


def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
    eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
    eval_dataset, size = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode)
    eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
    preds = None
    num_eval_steps = math.ceil(size / eval_batch_size)
    master = master_bar(range(1))
    eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args['n_device'] > 1)
    loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
    loss = 0.0

    logging.info("***** Running evaluation *****")
    logging.info("  Num examples = %d", size)
    logging.info("  Batch size = %d", eval_batch_size)

    for eval_features, eval_labels in eval_iterator:
        inputs = {'attention_mask': eval_features['input_mask'], 'training': False}

        if args['model_type'] != "distilbert":
            inputs["token_type_ids"] = eval_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None

        with strategy.scope():
            logits = model(eval_features['input_ids'], **inputs)[0]
            tmp_logits = tf.reshape(logits, (-1, len(labels) + 1))
            active_loss = tf.reshape(eval_features['input_mask'], (-1,))
            active_logits = tf.boolean_mask(tmp_logits, active_loss)
            tmp_eval_labels = tf.reshape(eval_labels, (-1,))
            active_labels = tf.boolean_mask(tmp_eval_labels, active_loss)
            cross_entropy = loss_fct(active_labels, active_logits)
            loss += tf.reduce_sum(cross_entropy) * (1.0 / eval_batch_size)

        if preds is None:
            preds = logits.numpy()
            label_ids = eval_labels.numpy()
        else:
            preds = np.append(preds, logits.numpy(), axis=0)
            label_ids = np.append(label_ids, eval_labels.numpy(), axis=0)

    preds = np.argmax(preds, axis=2)
    y_pred = [[] for _ in range(label_ids.shape[0])]
    y_true = [[] for _ in range(label_ids.shape[0])]
    loss = loss / num_eval_steps

    for i in range(label_ids.shape[0]):
        for j in range(label_ids.shape[1]):
            if label_ids[i, j] != pad_token_label_id:
                y_pred[i].append(labels[preds[i, j] - 1])
                y_true[i].append(labels[label_ids[i, j] - 1])

    return y_true, y_pred, loss.numpy()


def load_cache(cached_file, max_seq_length):
    name_to_features = {
        "input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
        "label_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
    }

    def _decode_record(record):
        example = tf.io.parse_single_example(record, name_to_features)
        features = {}
        features['input_ids'] = example['input_ids']
        features['input_mask'] = example['input_mask']
        features['segment_ids'] = example['segment_ids']

        return features, example['label_ids']

    d = tf.data.TFRecordDataset(cached_file)
    d = d.map(_decode_record, num_parallel_calls=4)
    count = d.reduce(0, lambda x, _: x + 1)

    return d, count.numpy()


def save_cache(features, cached_features_file):
    writer = tf.io.TFRecordWriter(cached_features_file)

    for (ex_index, feature) in enumerate(features):
        if ex_index % 5000 == 0:
            logging.info("Writing example %d of %d" % (ex_index, len(features)))

        def create_int_feature(values):
            f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
            return f

        record_feature = collections.OrderedDict()
        record_feature["input_ids"] = create_int_feature(feature.input_ids)
        record_feature["input_mask"] = create_int_feature(feature.input_mask)
        record_feature["segment_ids"] = create_int_feature(feature.segment_ids)
        record_feature["label_ids"] = create_int_feature(feature.label_ids)

        tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))

        writer.write(tf_example.SerializeToString())

    writer.close()


def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
    drop_remainder = True if args['tpu'] or mode == 'train' else False

    # Load data features from cache or dataset file
    cached_features_file = os.path.join(args['data_dir'], "cached_{}_{}_{}.tf_record".format(mode,
        list(filter(None, args['model_name_or_path'].split("/"))).pop(),
        str(args['max_seq_length'])))
    if os.path.exists(cached_features_file) and not args['overwrite_cache']:
        logging.info("Loading features from cached file %s", cached_features_file)
        dataset, size = load_cache(cached_features_file, args['max_seq_length'])
    else:
        logging.info("Creating features from dataset file at %s", args['data_dir'])
        examples = read_examples_from_file(args['data_dir'], mode)
        features = convert_examples_to_features(examples, labels, args['max_seq_length'], tokenizer,
                                                cls_token_at_end=bool(args['model_type'] in ["xlnet"]),
                                                # xlnet has a cls token at the end
                                                cls_token=tokenizer.cls_token,
                                                cls_token_segment_id=2 if args['model_type'] in ["xlnet"] else 0,
                                                sep_token=tokenizer.sep_token,
                                                sep_token_extra=bool(args['model_type'] in ["roberta"]),
                                                # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
                                                pad_on_left=bool(args['model_type'] in ["xlnet"]),
                                                # pad on the left for xlnet
                                                pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
                                                pad_token_segment_id=4 if args['model_type'] in ["xlnet"] else 0,
                                                pad_token_label_id=pad_token_label_id
                                                )
        logging.info("Saving features into cached file %s", cached_features_file)
        save_cache(features, cached_features_file)
        dataset, size = load_cache(cached_features_file, args['max_seq_length'])

    if mode == 'train':
        dataset = dataset.repeat()
        dataset = dataset.shuffle(buffer_size=8192, seed=args['seed'])

    dataset = dataset.batch(batch_size, drop_remainder)
    dataset = dataset.prefetch(buffer_size=batch_size)

    return dataset, size


def main(_):
    logging.set_verbosity(logging.INFO)
    args = flags.FLAGS.flag_values_dict()

    if os.path.exists(args['output_dir']) and os.listdir(
            args['output_dir']) and args['do_train'] and not args['overwrite_output_dir']:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args['output_dir']))

    if args['fp16']:
        tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

    if args['tpu']:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args['tpu'])
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        strategy = tf.distribute.experimental.TPUStrategy(resolver)
        args['n_device'] = args['num_tpu_cores']
    elif len(args['gpus'].split(',')) > 1:
        args['n_device'] = len([f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
        strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
    elif args['no_cuda']:
        args['n_device'] = 1
        strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
    else:
        args['n_device'] = len(args['gpus'].split(','))
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args['gpus'].split(',')[0])

    logging.warning("n_device: %s, distributed training: %s, 16-bits training: %s",
                    args['n_device'], bool(args['n_device'] > 1), args['fp16'])

    labels = get_labels(args['labels'])
    num_labels = len(labels) + 1
    pad_token_label_id = 0
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args['model_type']]
    config = config_class.from_pretrained(args['config_name'] if args['config_name'] else args['model_name_or_path'],
                                          num_labels=num_labels,
                                          cache_dir=args['cache_dir'] if args['cache_dir'] else None)

    logging.info("Training/evaluation parameters %s", args)

    # Training
    if args['do_train']:
        tokenizer = tokenizer_class.from_pretrained(args['tokenizer_name'] if args['tokenizer_name'] else args['model_name_or_path'],
                                                    do_lower_case=args['do_lower_case'],
                                                    cache_dir=args['cache_dir'] if args['cache_dir'] else None)

        with strategy.scope():
            model = model_class.from_pretrained(args['model_name_or_path'],
                                                from_pt=bool(".bin" in args['model_name_or_path']),
                                                config=config,
                                                cache_dir=args['cache_dir'] if args['cache_dir'] else None)
            model.layers[-1].activation = tf.keras.activations.softmax

        train_batch_size = args['per_device_train_batch_size'] * args['n_device']
        train_dataset, num_train_examples = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train")
        train_dataset = strategy.experimental_distribute_dataset(train_dataset)
        train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id)

        if not os.path.exists(args['output_dir']):
            os.makedirs(args['output_dir'])

        logging.info("Saving model to %s", args['output_dir'])

        model.save_pretrained(args['output_dir'])
        tokenizer.save_pretrained(args['output_dir'])

    # Evaluation
    if args['do_eval']:
        tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
        checkpoints = []
        results = []

        if args['eval_all_checkpoints']:
            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args['output_dir'] + "/**/" + TF2_WEIGHTS_NAME, recursive=True), key=lambda f: int(''.join(filter(str.isdigit, f)) or -1)))

        logging.info("Evaluate the following checkpoints: %s", checkpoints)

        if len(checkpoints) == 0:
            checkpoints.append(args['output_dir'])

        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"

            with strategy.scope():
                model = model_class.from_pretrained(checkpoint)

            y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
            report = metrics.classification_report(y_true, y_pred, digits=4)

            if global_step:
                results.append({global_step + "_report": report, global_step + "_loss": eval_loss})

        output_eval_file = os.path.join(args['output_dir'], "eval_results.txt")

        with tf.io.gfile.GFile(output_eval_file, "w") as writer:
            for res in results:
                for key, val in res.items():
                    if "loss" in key:
                        logging.info(key + " = " + str(val))
                        writer.write(key + " = " + str(val))
                        writer.write("\n")
                    else:
                        logging.info(key)
                        logging.info("\n" + report)
                        writer.write(key + "\n")
                        writer.write(report)
                        writer.write("\n")

    if args['do_predict']:
        tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
        model = model_class.from_pretrained(args['output_dir'])
        eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
        predict_dataset, _ = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test")
        y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
        output_test_results_file = os.path.join(args['output_dir'], "test_results.txt")
        output_test_predictions_file = os.path.join(args['output_dir'], "test_predictions.txt")
        report = metrics.classification_report(y_true, y_pred, digits=4)

        with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
            report = metrics.classification_report(y_true, y_pred, digits=4)

            logging.info("\n" + report)

            writer.write(report)
            writer.write("\n\nloss = " + str(pred_loss))

        with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
            with tf.io.gfile.GFile(os.path.join(args['data_dir'], "test.txt"), "r") as f:
                example_id = 0

                for line in f:
                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                        writer.write(line)

                        if not y_pred[example_id]:
                            example_id += 1
                    elif y_pred[example_id]:
                        output_line = line.split()[0] + " " + y_pred[example_id].pop(0) + "\n"
                        writer.write(output_line)
                    else:
                        logging.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])


if __name__ == "__main__":
    flags.mark_flag_as_required("data_dir")
    flags.mark_flag_as_required("output_dir")
    flags.mark_flag_as_required("model_name_or_path")
    flags.mark_flag_as_required("model_type")
    app.run(main)
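A note on the loss computation in `step_fn` above: padding positions are dropped with `tf.boolean_mask` before the cross-entropy is taken, so the pad label (id 0) never contributes to the loss. A self-contained sketch of that masking step with toy shapes (illustrative names, not tied to the script's data; raw logits here, so `from_logits=True` unlike the script, which sets a softmax activation on the classifier):

```python
import tensorflow as tf

batch_size, seq_len, num_labels = 2, 4, 5
logits = tf.random.normal((batch_size, seq_len, num_labels))
labels = tf.constant([[1, 2, 0, 0], [3, 1, 4, 0]])  # 0 marks padding
mask = tf.constant([[1, 1, 0, 0], [1, 1, 1, 0]])    # input_mask: 1 = real token

# Flatten everything, then keep only positions where the mask is set.
flat_logits = tf.reshape(logits, (-1, num_labels))
flat_labels = tf.reshape(labels, (-1,))
flat_mask = tf.cast(tf.reshape(mask, (-1,)), tf.bool)

active_logits = tf.boolean_mask(flat_logits, flat_mask)
active_labels = tf.boolean_mask(flat_labels, flat_mask)

loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
per_token_loss = loss_fct(active_labels, active_logits)  # shape: (num_real_tokens,)
```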
@ -164,6 +164,7 @@ if is_tf_available():
    from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
                                         TFDistilBertModel, TFDistilBertForMaskedLM,
                                         TFDistilBertForSequenceClassification,
                                         TFDistilBertForTokenClassification,
                                         TFDistilBertForQuestionAnswering,
                                         TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)

@ -174,6 +175,8 @@ if is_tf_available():
    from .modeling_tf_albert import (TFAlbertPreTrainedModel, TFAlbertModel, TFAlbertForMaskedLM,
                                     TFAlbertForSequenceClassification,
                                     TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
    # Optimization
    from .optimization_tf import (WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator)

    # TF 2.0 <=> PyTorch conversion utilities
    from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
@ -704,6 +704,53 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
        return outputs  # logits, (hidden_states), (attentions)


@add_start_docstrings("""DistilBert Model with a token classification head on top (a linear layer on top of
                      the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
                      DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
            Classification scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import DistilBertTokenizer, TFDistilBertForTokenClassification

        tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
        model = TFDistilBertForTokenClassification.from_pretrained('distilbert-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        scores = outputs[0]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFDistilBertForTokenClassification, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.distilbert = TFDistilBertMainLayer(config, name='distilbert')
        self.dropout = tf.keras.layers.Dropout(config.dropout)
        self.classifier = tf.keras.layers.Dense(config.num_labels,
                                                kernel_initializer=get_initializer(config.initializer_range),
                                                name='classifier')

    def call(self, inputs, **kwargs):
        outputs = self.distilbert(inputs, **kwargs)

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        return outputs  # scores, (hidden_states), (attentions)


@add_start_docstrings("""DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
                      the hidden-states output to compute `span start logits` and `span end logits`). """,
                      DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
@ -0,0 +1,254 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

import tensorflow as tf


class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Applies a warmup schedule on a given learning rate decay schedule."""

    def __init__(
            self,
            initial_learning_rate,
            decay_schedule_fn,
            warmup_steps,
            power=1.0,
            name=None):
        super(WarmUp, self).__init__()
        self.initial_learning_rate = initial_learning_rate
        self.warmup_steps = warmup_steps
        self.power = power
        self.decay_schedule_fn = decay_schedule_fn
        self.name = name

    def __call__(self, step):
        with tf.name_scope(self.name or 'WarmUp') as name:
            # Implements polynomial warmup. i.e., if global_step < warmup_steps, the
            # learning rate will be `global_step/num_warmup_steps * init_lr`.
            global_step_float = tf.cast(step, tf.float32)
            warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
            warmup_percent_done = global_step_float / warmup_steps_float
            warmup_learning_rate = (
                self.initial_learning_rate *
                tf.math.pow(warmup_percent_done, self.power))
            return tf.cond(global_step_float < warmup_steps_float,
                           lambda: warmup_learning_rate,
                           lambda: self.decay_schedule_fn(step),
                           name=name)

    def get_config(self):
        return {
            'initial_learning_rate': self.initial_learning_rate,
            'decay_schedule_fn': self.decay_schedule_fn,
            'warmup_steps': self.warmup_steps,
            'power': self.power,
            'name': self.name
        }


def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
    """Creates an optimizer with learning rate schedule."""
    # Implements linear decay of the learning rate.
    learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=init_lr,
        decay_steps=num_train_steps,
        end_learning_rate=0.0)
    if num_warmup_steps:
        learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
                                  decay_schedule_fn=learning_rate_fn,
                                  warmup_steps=num_warmup_steps)
    optimizer = AdamWeightDecay(
        learning_rate=learning_rate_fn,
        weight_decay_rate=0.01,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-6,
        exclude_from_weight_decay=['layer_norm', 'bias'])
    return optimizer


class AdamWeightDecay(tf.keras.optimizers.Adam):
    """Adam enables L2 weight decay and clip_by_global_norm on gradients.

    Just adding the square of the weights to the loss function is *not* the
    correct way of using L2 regularization/weight decay with Adam, since that will
    interact with the m and v parameters in strange ways.

    Instead we want to decay the weights in a manner that doesn't interact with
    the m/v parameters. This is equivalent to adding the square of the weights to
    the loss with plain (non-momentum) SGD.
    """

    def __init__(self,
                 learning_rate=0.001,
                 beta_1=0.9,
                 beta_2=0.999,
                 epsilon=1e-7,
                 amsgrad=False,
                 weight_decay_rate=0.0,
                 include_in_weight_decay=None,
                 exclude_from_weight_decay=None,
                 name='AdamWeightDecay',
                 **kwargs):
        super(AdamWeightDecay, self).__init__(
            learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
        self.weight_decay_rate = weight_decay_rate
        self._include_in_weight_decay = include_in_weight_decay
        self._exclude_from_weight_decay = exclude_from_weight_decay

    @classmethod
    def from_config(cls, config):
        """Creates an optimizer from its config with WarmUp custom object."""
        custom_objects = {'WarmUp': WarmUp}
        return super(AdamWeightDecay, cls).from_config(
            config, custom_objects=custom_objects)

    def _prepare_local(self, var_device, var_dtype, apply_state):
        super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype,
                                                    apply_state)
        apply_state['weight_decay_rate'] = tf.constant(
            self.weight_decay_rate, name='adam_weight_decay_rate')

    def _decay_weights_op(self, var, learning_rate, apply_state):
        do_decay = self._do_use_weight_decay(var.name)
        if do_decay:
            return var.assign_sub(
                learning_rate * var *
                apply_state['weight_decay_rate'],
                use_locking=self._use_locking)
        return tf.no_op()

    def apply_gradients(self, grads_and_vars, clip_norm, name=None):
        grads, tvars = list(zip(*grads_and_vars))
        (grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
        return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))

    def _get_lr(self, var_device, var_dtype, apply_state):
        """Retrieves the learning rate with the given state."""
        if apply_state is None:
            return self._decayed_lr_t[var_dtype], {}

        apply_state = apply_state or {}
        coefficients = apply_state.get((var_device, var_dtype))
        if coefficients is None:
            coefficients = self._fallback_apply_state(var_device, var_dtype)
            apply_state[(var_device, var_dtype)] = coefficients

        return coefficients['lr_t'], dict(apply_state=apply_state)

    def _resource_apply_dense(self, grad, var, apply_state=None):
        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_dense(
                grad, var, **kwargs)

    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
        decay = self._decay_weights_op(var, lr_t, apply_state)
        with tf.control_dependencies([decay]):
            return super(AdamWeightDecay, self)._resource_apply_sparse(
                grad, var, indices, **kwargs)

    def get_config(self):
        config = super(AdamWeightDecay, self).get_config()
        config.update({
            'weight_decay_rate': self.weight_decay_rate,
        })
        return config

    def _do_use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if self.weight_decay_rate == 0:
            return False

        if self._include_in_weight_decay:
            for r in self._include_in_weight_decay:
                if re.search(r, param_name) is not None:
                    return True

        if self._exclude_from_weight_decay:
            for r in self._exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True


## Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
class GradientAccumulator(object):
    """Distribution strategies-aware gradient accumulation utility."""

    def __init__(self):
        """Initializes the accumulator."""
        self._gradients = []
        self._accum_steps = tf.Variable(
            initial_value=0,
            dtype=tf.int64,
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

    @property
    def step(self):
        """Number of accumulated steps."""
        return self._accum_steps.value()

    @property
    def gradients(self):
        """The accumulated gradients."""
        return list(gradient.value() if gradient is not None else gradient for gradient in self._get_replica_gradients())

    def __call__(self, gradients):
        """Accumulates :obj:`gradients`."""
        if not self._gradients:
            self._gradients.extend([tf.Variable(tf.zeros_like(gradient), trainable=False) if gradient is not None else gradient for gradient in gradients])

        if len(gradients) != len(self._gradients):
            raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients)))

        for accum_gradient, gradient in zip(self._get_replica_gradients(), gradients):
            if accum_gradient is not None:
                accum_gradient.assign_add(gradient)

        self._accum_steps.assign_add(1)

    def reset(self):
        """Resets the accumulated gradients."""
        if self._gradients:
            self._accum_steps.assign(0)

            for gradient in self._get_replica_gradients():
                if gradient is not None:
                    gradient.assign(tf.zeros_like(gradient))

    def _get_replica_gradients(self):
        if tf.distribute.has_strategy():
            # In a replica context, we want to accumulate gradients on each replica
            # without synchronization, so we directly assign the value of the
            # current replica.
            replica_context = tf.distribute.get_replica_context()

            if replica_context is None or tf.distribute.get_strategy().num_replicas_in_sync == 1:
                return self._gradients

            return (gradient.device_map.select_for_current_replica(gradient.values, replica_context) for gradient in self._gradients)
        else:
            return self._gradients
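As a quick sanity check on the schedule above (a sketch, not part of the diff): with `initial_learning_rate=1e-3`, `decay_steps=100`, and `warmup_steps=10`, the learning rate should rise linearly to 1e-3 over the first 10 steps, then follow the linear decay to 0 at step 100:

```python
import tensorflow as tf
from transformers import WarmUp  # exported via transformers/__init__.py in this PR

# Rebuild the schedule the same way create_optimizer does.
decay = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1e-3, decay_steps=100, end_learning_rate=0.0)
schedule = WarmUp(initial_learning_rate=1e-3, decay_schedule_fn=decay, warmup_steps=10)

for step in [1, 5, 10, 50, 100]:
    print(step, float(schedule(step)))
# Expected: 1e-4 and 5e-4 during warmup (step/warmup_steps * init_lr), then the
# decay branch: 9e-4 at step 10, 5e-4 at step 50, 0.0 at step 100.
```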
@ -0,0 +1,89 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import unittest
import pytest

from transformers import is_tf_available

if is_tf_available():
    import tensorflow as tf
    from tensorflow.python.eager import context
    from tensorflow.python.framework import ops
    from transformers import (create_optimizer, GradientAccumulator)
else:
    pytestmark = pytest.mark.skip("Require TensorFlow")


class OptimizationFTest(unittest.TestCase):
    def assertListAlmostEqual(self, list1, list2, tol):
        self.assertEqual(len(list1), len(list2))
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol)

    def testGradientAccumulator(self):
        accumulator = GradientAccumulator()
        accumulator([tf.constant([1.0, 2.0])])
        accumulator([tf.constant([-2.0, 1.0])])
        accumulator([tf.constant([-1.0, 2.0])])
        with self.assertRaises(ValueError):
            accumulator([tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])])
        self.assertEqual(accumulator.step, 3)
        self.assertEqual(len(accumulator.gradients), 1)
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [-2.0, 5.0], tol=1e-2)
        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [0.0, 0.0], tol=1e-2)

    def testGradientAccumulatorDistributionStrategy(self):
        context._context = None
        ops.enable_eager_execution_internal()
        physical_devices = tf.config.experimental.list_physical_devices("CPU")
        tf.config.experimental.set_virtual_device_configuration(
            physical_devices[0],
            [tf.config.experimental.VirtualDeviceConfiguration(),
             tf.config.experimental.VirtualDeviceConfiguration()])

        devices = tf.config.experimental.list_logical_devices(device_type="CPU")
        strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices])

        with strategy.scope():
            accumulator = GradientAccumulator()
            variable = tf.Variable([4.0, 3.0])
            optimizer = create_optimizer(5e-5, 10, 5)
            gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False)

        def accumulate_on_replica(gradient):
            accumulator([gradient])

        def apply_on_replica():
            optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])), 1.0)

        @tf.function
        def accumulate(grad1, grad2):
            with strategy.scope():
                gradient_placeholder.values[0].assign(grad1)
                gradient_placeholder.values[1].assign(grad2)
                strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,))

        @tf.function
        def apply_grad():
            with strategy.scope():
                strategy.experimental_run_v2(apply_on_replica)

        accumulate([1.0, 2.0], [-1.0, 1.0])
        accumulate([3.0, -1.0], [-1.0, -1.0])
        accumulate([-2.0, 2.0], [3.0, -2.0])
        self.assertEqual(accumulator.step, 3)
        self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [2.0, 3.0], tol=1e-2)
        self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [1.0, -2.0], tol=1e-2)
        apply_grad()
        self.assertListAlmostEqual(variable.value().numpy().tolist(), [4.0, 3.0], tol=1e-2)
        accumulator.reset()
        self.assertEqual(accumulator.step, 0)
        self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
        self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)


if __name__ == "__main__":
    unittest.main()