Merge pull request #2046 from jplu/tf2-ner-example

Add NER TF2 example.
This commit is contained in:
Thomas Wolf 2019-12-06 12:12:22 +01:00 committed by GitHub
commit 5482822a2b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1105 additions and 4 deletions

View File

@ -5,6 +5,7 @@ The ``.optimization`` module provides:
- an optimizer with weight decay fixed that can be used to fine-tuned models, and
- several schedules in the form of schedule objects that inherit from ``_LRSchedule``:
- a gradient accumulation class to accumulate the gradients of multiple batches
``AdamW``
~~~~~~~~~~~~~~~~
@ -12,6 +13,15 @@ The ``.optimization`` module provides:
.. autoclass:: transformers.AdamW
:members:
``AdamWeightDecay``
~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.AdamWeightDecay
:members:
.. autofunction:: transformers.create_optimizer
:members:
Schedules
----------------------------------------------------
@ -49,3 +59,17 @@ Learning Rate Schedules
.. image:: /imgs/warmup_linear_schedule.png
:target: /imgs/warmup_linear_schedule.png
:alt:
``Warmup``
~~~~~~~~~~~~~~~~
.. autoclass:: transformers.Warmup
:members:
Gradient Strategies
----------------------------------------------------
``GradientAccumulator``
~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.GradientAccumulator

View File

@ -467,7 +467,8 @@ Training with the previously defined hyper-parameters yields the following resul
## Named Entity Recognition
Based on the script [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py).
Based on the scripts [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/run_ner.py) for Pytorch and
[`run_tf_ner.py`(https://github.com/huggingface/transformers/blob/master/examples/run_tf_ner.py)] for Tensorflow 2.
This example fine-tune Bert Multilingual on GermEval 2014 (German NER).
Details and results for the fine-tuning provided by @stefan-it.
@ -512,7 +513,7 @@ The GermEval 2014 dataset has much more labels than CoNLL-2002/2003 datasets, so
cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt
```
### Training
### Prepare the run
Additional environment variables must be set:
@ -524,6 +525,8 @@ export SAVE_STEPS=750
export SEED=1
```
### Run the Pytorch version
To start training, just run:
```bash
@ -544,7 +547,7 @@ python3 run_ner.py --data_dir ./ \
If your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be both evaluated on development and test datasets.
### Evaluation
#### Evaluation
Evaluation on development dataset outputs the following for our example:
@ -566,7 +569,7 @@ On the test dataset the following results could be achieved:
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
```
### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)
#### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)
Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run):
@ -576,6 +579,72 @@ Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) a
| `roberta-large` | 95.96 | 91.87
| `distilbert-base-uncased` | 94.34 | 90.32
### Run the Tensorflow 2 version
To start training, just run:
```bash
python3 run_tf_ner.py --data_dir ./ \
--model_type bert \
--labels ./labels.txt \
--model_name_or_path $BERT_MODEL \
--output_dir $OUTPUT_DIR \
--max_seq_length $MAX_LENGTH \
--num_train_epochs $NUM_EPOCHS \
--per_device_train_batch_size $BATCH_SIZE \
--save_steps $SAVE_STEPS \
--seed $SEED \
--do_train \
--do_eval \
--do_predict
```
Such as the Pytorch version, if your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be both evaluated on development and test datasets.
#### Evaluation
Evaluation on development dataset outputs the following for our example:
```bash
precision recall f1-score support
LOCderiv 0.7619 0.6154 0.6809 52
PERpart 0.8724 0.8997 0.8858 4057
OTHpart 0.9360 0.9466 0.9413 711
ORGpart 0.7015 0.6989 0.7002 269
LOCpart 0.7668 0.8488 0.8057 496
LOC 0.8745 0.9191 0.8963 235
ORGderiv 0.7723 0.8571 0.8125 91
OTHderiv 0.4800 0.6667 0.5581 18
OTH 0.5789 0.6875 0.6286 16
PERderiv 0.5385 0.3889 0.4516 18
PER 0.5000 0.5000 0.5000 2
ORG 0.0000 0.0000 0.0000 3
micro avg 0.8574 0.8862 0.8715 5968
macro avg 0.8575 0.8862 0.8713 5968
```
On the test dataset the following results could be achieved:
```bash
precision recall f1-score support
PERpart 0.8847 0.8944 0.8896 9397
OTHpart 0.9376 0.9353 0.9365 1639
ORGpart 0.7307 0.7044 0.7173 697
LOC 0.9133 0.9394 0.9262 561
LOCpart 0.8058 0.8157 0.8107 1150
ORG 0.0000 0.0000 0.0000 8
OTHderiv 0.5882 0.4762 0.5263 42
PERderiv 0.6571 0.5227 0.5823 44
OTH 0.4906 0.6667 0.5652 39
ORGderiv 0.7016 0.7791 0.7383 172
LOCderiv 0.8256 0.6514 0.7282 109
PER 0.0000 0.0000 0.0000 11
micro avg 0.8722 0.8774 0.8748 13869
macro avg 0.8712 0.8774 0.8740 13869
```
## Abstractive summarization
Based on the script

615
examples/run_tf_ner.py Normal file
View File

@ -0,0 +1,615 @@
# coding=utf-8
import datetime
import os
import math
import glob
import re
import tensorflow as tf
import collections
import numpy as np
from seqeval import metrics
import _pickle as pickle
from absl import logging
from transformers import TF2_WEIGHTS_NAME, BertConfig, BertTokenizer, TFBertForTokenClassification
from transformers import RobertaConfig, RobertaTokenizer, TFRobertaForTokenClassification
from transformers import DistilBertConfig, DistilBertTokenizer, TFDistilBertForTokenClassification
from transformers import create_optimizer, GradientAccumulator
from utils_ner import convert_examples_to_features, get_labels, read_examples_from_file
from fastprogress import master_bar, progress_bar
from absl import flags
from absl import app
ALL_MODELS = sum(
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, RobertaConfig, DistilBertConfig)),
())
MODEL_CLASSES = {
"bert": (BertConfig, TFBertForTokenClassification, BertTokenizer),
"roberta": (RobertaConfig, TFRobertaForTokenClassification, RobertaTokenizer),
"distilbert": (DistilBertConfig, TFDistilBertForTokenClassification, DistilBertTokenizer)
}
flags.DEFINE_string(
"data_dir", None,
"The input data dir. Should contain the .conll files (or other data files) "
"for the task.")
flags.DEFINE_string(
"model_type", None,
"Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
flags.DEFINE_string(
"model_name_or_path", None,
"Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
flags.DEFINE_string(
"output_dir", None,
"The output directory where the model checkpoints will be written.")
flags.DEFINE_string(
"labels", "",
"Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.")
flags.DEFINE_string(
"config_name", "",
"Pretrained config name or path if not the same as model_name")
flags.DEFINE_string(
"tokenizer_name", "",
"Pretrained tokenizer name or path if not the same as model_name")
flags.DEFINE_string(
"cache_dir", "",
"Where do you want to store the pre-trained models downloaded from s3")
flags.DEFINE_integer(
"max_seq_length", 128,
"The maximum total input sentence length after tokenization. "
"Sequences longer than this will be truncated, sequences shorter "
"will be padded.")
flags.DEFINE_string(
"tpu", None,
"The Cloud TPU to use for training. This should be either the name "
"used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
"url.")
flags.DEFINE_integer(
"num_tpu_cores", 8,
"Total number of TPU cores to use.")
flags.DEFINE_boolean(
"do_train", False,
"Whether to run training.")
flags.DEFINE_boolean(
"do_eval", False,
"Whether to run eval on the dev set.")
flags.DEFINE_boolean(
"do_predict", False,
"Whether to run predictions on the test set.")
flags.DEFINE_boolean(
"evaluate_during_training", False,
"Whether to run evaluation during training at each logging step.")
flags.DEFINE_boolean(
"do_lower_case", False,
"Set this flag if you are using an uncased model.")
flags.DEFINE_integer(
"per_device_train_batch_size", 8,
"Batch size per GPU/CPU/TPU for training.")
flags.DEFINE_integer(
"per_device_eval_batch_size", 8,
"Batch size per GPU/CPU/TPU for evaluation.")
flags.DEFINE_integer(
"gradient_accumulation_steps", 1,
"Number of updates steps to accumulate before performing a backward/update pass.")
flags.DEFINE_float(
"learning_rate", 5e-5,
"The initial learning rate for Adam.")
flags.DEFINE_float(
"weight_decay", 0.0,
"Weight decay if we apply some.")
flags.DEFINE_float(
"adam_epsilon", 1e-8,
"Epsilon for Adam optimizer.")
flags.DEFINE_float(
"max_grad_norm", 1.0,
"Max gradient norm.")
flags.DEFINE_integer(
"num_train_epochs", 3,
"Total number of training epochs to perform.")
flags.DEFINE_integer(
"max_steps", -1,
"If > 0: set total number of training steps to perform. Override num_train_epochs.")
flags.DEFINE_integer(
"warmup_steps", 0,
"Linear warmup over warmup_steps.")
flags.DEFINE_integer(
"logging_steps", 50,
"Log every X updates steps.")
flags.DEFINE_integer(
"save_steps", 50,
"Save checkpoint every X updates steps.")
flags.DEFINE_boolean(
"eval_all_checkpoints", False,
"Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
flags.DEFINE_boolean(
"no_cuda", False,
"Avoid using CUDA when available")
flags.DEFINE_boolean(
"overwrite_output_dir", False,
"Overwrite the content of the output directory")
flags.DEFINE_boolean(
"overwrite_cache", False,
"Overwrite the cached training and evaluation sets")
flags.DEFINE_integer(
"seed", 42,
"random seed for initialization")
flags.DEFINE_boolean(
"fp16", False,
"Whether to use 16-bit (mixed) precision instead of 32-bit")
flags.DEFINE_string(
"gpus", "0",
"Comma separated list of gpus devices. If only one, switch to single "
"gpu strategy, if None takes all the gpus available.")
def train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id):
if args['max_steps'] > 0:
num_train_steps = args['max_steps'] * args['gradient_accumulation_steps']
args['num_train_epochs'] = 1
else:
num_train_steps = math.ceil(num_train_examples / train_batch_size) // args['gradient_accumulation_steps'] * args['num_train_epochs']
writer = tf.summary.create_file_writer("/tmp/mylogs")
with strategy.scope():
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
optimizer = create_optimizer(args['learning_rate'], num_train_steps, args['warmup_steps'])
if args['fp16']:
optimizer = tf.keras.mixed_precision.experimental.LossScaleOptimizer(optimizer, 'dynamic')
loss_metric = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
gradient_accumulator = GradientAccumulator()
logging.info("***** Running training *****")
logging.info(" Num examples = %d", num_train_examples)
logging.info(" Num Epochs = %d", args['num_train_epochs'])
logging.info(" Instantaneous batch size per device = %d", args['per_device_train_batch_size'])
logging.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
train_batch_size * args['gradient_accumulation_steps'])
logging.info(" Gradient Accumulation steps = %d", args['gradient_accumulation_steps'])
logging.info(" Total training steps = %d", num_train_steps)
model.summary()
@tf.function
def apply_gradients():
grads_and_vars = []
for gradient, variable in zip(gradient_accumulator.gradients, model.trainable_variables):
if gradient is not None:
scaled_gradient = gradient / (args['n_device'] * args['gradient_accumulation_steps'])
grads_and_vars.append((scaled_gradient, variable))
else:
grads_and_vars.append((gradient, variable))
optimizer.apply_gradients(grads_and_vars, args['max_grad_norm'])
gradient_accumulator.reset()
@tf.function
def train_step(train_features, train_labels):
def step_fn(train_features, train_labels):
inputs = {'attention_mask': train_features['input_mask'], 'training': True}
if args['model_type'] != "distilbert":
inputs["token_type_ids"] = train_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None
with tf.GradientTape() as tape:
logits = model(train_features['input_ids'], **inputs)[0]
logits = tf.reshape(logits, (-1, len(labels) + 1))
active_loss = tf.reshape(train_features['input_mask'], (-1,))
active_logits = tf.boolean_mask(logits, active_loss)
train_labels = tf.reshape(train_labels, (-1,))
active_labels = tf.boolean_mask(train_labels, active_loss)
cross_entropy = loss_fct(active_labels, active_logits)
loss = tf.reduce_sum(cross_entropy) * (1.0 / train_batch_size)
grads = tape.gradient(loss, model.trainable_variables)
gradient_accumulator(grads)
return cross_entropy
per_example_losses = strategy.experimental_run_v2(step_fn, args=(train_features, train_labels))
mean_loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_example_losses, axis=0)
return mean_loss
current_time = datetime.datetime.now()
train_iterator = master_bar(range(args['num_train_epochs']))
global_step = 0
logging_loss = 0.0
for epoch in train_iterator:
epoch_iterator = progress_bar(train_dataset, total=num_train_steps, parent=train_iterator, display=args['n_device'] > 1)
step = 1
with strategy.scope():
for train_features, train_labels in epoch_iterator:
loss = train_step(train_features, train_labels)
if step % args['gradient_accumulation_steps'] == 0:
strategy.experimental_run_v2(apply_gradients)
loss_metric(loss)
global_step += 1
if args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0:
# Log metrics
if args['n_device'] == 1 and args['evaluate_during_training']: # Only evaluate when single GPU otherwise metrics may not average well
y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
report = metrics.classification_report(y_true, y_pred, digits=4)
logging.info("Eval at step " + str(global_step) + "\n" + report)
logging.info("eval_loss: " + str(eval_loss))
precision = metrics.precision_score(y_true, y_pred)
recall = metrics.recall_score(y_true, y_pred)
f1 = metrics.f1_score(y_true, y_pred)
with writer.as_default():
tf.summary.scalar("eval_loss", eval_loss, global_step)
tf.summary.scalar("precision", precision, global_step)
tf.summary.scalar("recall", recall, global_step)
tf.summary.scalar("f1", f1, global_step)
lr = optimizer.learning_rate
learning_rate = lr(step)
with writer.as_default():
tf.summary.scalar("lr", learning_rate, global_step)
tf.summary.scalar("loss", (loss_metric.result() - logging_loss) / args['logging_steps'], global_step)
logging_loss = loss_metric.result()
with writer.as_default():
tf.summary.scalar("loss", loss_metric.result(), step=step)
if args['save_steps'] > 0 and global_step % args['save_steps'] == 0:
# Save model checkpoint
output_dir = os.path.join(args['output_dir'], "checkpoint-{}".format(global_step))
if not os.path.exists(output_dir):
os.makedirs(output_dir)
model.save_pretrained(output_dir)
logging.info("Saving model checkpoint to %s", output_dir)
train_iterator.child.comment = f'loss : {loss_metric.result()}'
step += 1
train_iterator.write(f'loss epoch {epoch + 1}: {loss_metric.result()}')
loss_metric.reset_states()
logging.info(" Training took time = {}".format(datetime.datetime.now() - current_time))
def evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode):
eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
eval_dataset, size = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode=mode)
eval_dataset = strategy.experimental_distribute_dataset(eval_dataset)
preds = None
num_eval_steps = math.ceil(size / eval_batch_size)
master = master_bar(range(1))
eval_iterator = progress_bar(eval_dataset, total=num_eval_steps, parent=master, display=args['n_device'] > 1)
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
loss = 0.0
logging.info("***** Running evaluation *****")
logging.info(" Num examples = %d", size)
logging.info(" Batch size = %d", eval_batch_size)
for eval_features, eval_labels in eval_iterator:
inputs = {'attention_mask': eval_features['input_mask'], 'training': False}
if args['model_type'] != "distilbert":
inputs["token_type_ids"] = eval_features['segment_ids'] if args['model_type'] in ["bert", "xlnet"] else None
with strategy.scope():
logits = model(eval_features['input_ids'], **inputs)[0]
tmp_logits = tf.reshape(logits, (-1, len(labels) + 1))
active_loss = tf.reshape(eval_features['input_mask'], (-1,))
active_logits = tf.boolean_mask(tmp_logits, active_loss)
tmp_eval_labels = tf.reshape(eval_labels, (-1,))
active_labels = tf.boolean_mask(tmp_eval_labels, active_loss)
cross_entropy = loss_fct(active_labels, active_logits)
loss += tf.reduce_sum(cross_entropy) * (1.0 / eval_batch_size)
if preds is None:
preds = logits.numpy()
label_ids = eval_labels.numpy()
else:
preds = np.append(preds, logits.numpy(), axis=0)
label_ids = np.append(label_ids, eval_labels.numpy(), axis=0)
preds = np.argmax(preds, axis=2)
y_pred = [[] for _ in range(label_ids.shape[0])]
y_true = [[] for _ in range(label_ids.shape[0])]
loss = loss / num_eval_steps
for i in range(label_ids.shape[0]):
for j in range(label_ids.shape[1]):
if label_ids[i, j] != pad_token_label_id:
y_pred[i].append(labels[preds[i, j] - 1])
y_true[i].append(labels[label_ids[i, j] - 1])
return y_true, y_pred, loss.numpy()
def load_cache(cached_file, max_seq_length):
name_to_features = {
"input_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
"input_mask": tf.io.FixedLenFeature([max_seq_length], tf.int64),
"segment_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
"label_ids": tf.io.FixedLenFeature([max_seq_length], tf.int64),
}
def _decode_record(record):
example = tf.io.parse_single_example(record, name_to_features)
features = {}
features['input_ids'] = example['input_ids']
features['input_mask'] = example['input_mask']
features['segment_ids'] = example['segment_ids']
return features, example['label_ids']
d = tf.data.TFRecordDataset(cached_file)
d = d.map(_decode_record, num_parallel_calls=4)
count = d.reduce(0, lambda x, _: x + 1)
return d, count.numpy()
def save_cache(features, cached_features_file):
writer = tf.io.TFRecordWriter(cached_features_file)
for (ex_index, feature) in enumerate(features):
if ex_index % 5000 == 0:
logging.info("Writing example %d of %d" % (ex_index, len(features)))
def create_int_feature(values):
f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values)))
return f
record_feature = collections.OrderedDict()
record_feature["input_ids"] = create_int_feature(feature.input_ids)
record_feature["input_mask"] = create_int_feature(feature.input_mask)
record_feature["segment_ids"] = create_int_feature(feature.segment_ids)
record_feature["label_ids"] = create_int_feature(feature.label_ids)
tf_example = tf.train.Example(features=tf.train.Features(feature=record_feature))
writer.write(tf_example.SerializeToString())
writer.close()
def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, batch_size, mode):
drop_remainder = True if args['tpu'] or mode == 'train' else False
# Load data features from cache or dataset file
cached_features_file = os.path.join(args['data_dir'], "cached_{}_{}_{}.tf_record".format(mode,
list(filter(None, args['model_name_or_path'].split("/"))).pop(),
str(args['max_seq_length'])))
if os.path.exists(cached_features_file) and not args['overwrite_cache']:
logging.info("Loading features from cached file %s", cached_features_file)
dataset, size = load_cache(cached_features_file, args['max_seq_length'])
else:
logging.info("Creating features from dataset file at %s", args['data_dir'])
examples = read_examples_from_file(args['data_dir'], mode)
features = convert_examples_to_features(examples, labels, args['max_seq_length'], tokenizer,
cls_token_at_end=bool(args['model_type'] in ["xlnet"]),
# xlnet has a cls token at the end
cls_token=tokenizer.cls_token,
cls_token_segment_id=2 if args['model_type'] in ["xlnet"] else 0,
sep_token=tokenizer.sep_token,
sep_token_extra=bool(args['model_type'] in ["roberta"]),
# roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
pad_on_left=bool(args['model_type'] in ["xlnet"]),
# pad on the left for xlnet
pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
pad_token_segment_id=4 if args['model_type'] in ["xlnet"] else 0,
pad_token_label_id=pad_token_label_id
)
logging.info("Saving features into cached file %s", cached_features_file)
save_cache(features, cached_features_file)
dataset, size = load_cache(cached_features_file, args['max_seq_length'])
if mode == 'train':
dataset = dataset.repeat()
dataset = dataset.shuffle(buffer_size=8192, seed=args['seed'])
dataset = dataset.batch(batch_size, drop_remainder)
dataset = dataset.prefetch(buffer_size=batch_size)
return dataset, size
def main(_):
logging.set_verbosity(logging.INFO)
args = flags.FLAGS.flag_values_dict()
if os.path.exists(args['output_dir']) and os.listdir(
args['output_dir']) and args['do_train'] and not args['overwrite_output_dir']:
raise ValueError(
"Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
args['output_dir']))
if args['fp16']:
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
if args['tpu']:
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=args['tpu'])
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.experimental.TPUStrategy(resolver)
args['n_device'] = args['num_tpu_cores']
elif len(args['gpus'].split(',')) > 1:
args['n_device'] = len([f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
strategy = tf.distribute.MirroredStrategy(devices=[f"/gpu:{gpu}" for gpu in args['gpus'].split(',')])
elif args['no_cuda']:
args['n_device'] = 1
strategy = tf.distribute.OneDeviceStrategy(device="/cpu:0")
else:
args['n_device'] = len(args['gpus'].split(','))
strategy = tf.distribute.OneDeviceStrategy(device="/gpu:" + args['gpus'].split(',')[0])
logging.warning("n_device: %s, distributed training: %s, 16-bits training: %s",
args['n_device'], bool(args['n_device'] > 1), args['fp16'])
labels = get_labels(args['labels'])
num_labels = len(labels) + 1
pad_token_label_id = 0
config_class, model_class, tokenizer_class = MODEL_CLASSES[args['model_type']]
config = config_class.from_pretrained(args['config_name'] if args['config_name'] else args['model_name_or_path'],
num_labels=num_labels,
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
logging.info("Training/evaluation parameters %s", args)
# Training
if args['do_train']:
tokenizer = tokenizer_class.from_pretrained(args['tokenizer_name'] if args['tokenizer_name'] else args['model_name_or_path'],
do_lower_case=args['do_lower_case'],
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
with strategy.scope():
model = model_class.from_pretrained(args['model_name_or_path'],
from_pt=bool(".bin" in args['model_name_or_path']),
config=config,
cache_dir=args['cache_dir'] if args['cache_dir'] else None)
model.layers[-1].activation = tf.keras.activations.softmax
train_batch_size = args['per_device_train_batch_size'] * args['n_device']
train_dataset, num_train_examples = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, train_batch_size, mode="train")
train_dataset = strategy.experimental_distribute_dataset(train_dataset)
train(args, strategy, train_dataset, tokenizer, model, num_train_examples, labels, train_batch_size, pad_token_label_id)
if not os.path.exists(args['output_dir']):
os.makedirs(args['output_dir'])
logging.info("Saving model to %s", args['output_dir'])
model.save_pretrained(args['output_dir'])
tokenizer.save_pretrained(args['output_dir'])
# Evaluation
if args['do_eval']:
tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
checkpoints = []
results = []
if args['eval_all_checkpoints']:
checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args['output_dir'] + "/**/" + TF2_WEIGHTS_NAME, recursive=True), key=lambda f: int(''.join(filter(str.isdigit, f)) or -1)))
logging.info("Evaluate the following checkpoints: %s", checkpoints)
if len(checkpoints) == 0:
checkpoints.append(args['output_dir'])
for checkpoint in checkpoints:
global_step = checkpoint.split("-")[-1] if re.match(".*checkpoint-[0-9]", checkpoint) else "final"
with strategy.scope():
model = model_class.from_pretrained(checkpoint)
y_true, y_pred, eval_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="dev")
report = metrics.classification_report(y_true, y_pred, digits=4)
if global_step:
results.append({global_step + "_report": report, global_step + "_loss": eval_loss})
output_eval_file = os.path.join(args['output_dir'], "eval_results.txt")
with tf.io.gfile.GFile(output_eval_file, "w") as writer:
for res in results:
for key, val in res.items():
if "loss" in key:
logging.info(key + " = " + str(val))
writer.write(key + " = " + str(val))
writer.write("\n")
else:
logging.info(key)
logging.info("\n" + report)
writer.write(key + "\n")
writer.write(report)
writer.write("\n")
if args['do_predict']:
tokenizer = tokenizer_class.from_pretrained(args['output_dir'], do_lower_case=args['do_lower_case'])
model = model_class.from_pretrained(args['output_dir'])
eval_batch_size = args['per_device_eval_batch_size'] * args['n_device']
predict_dataset, _ = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, eval_batch_size, mode="test")
y_true, y_pred, pred_loss = evaluate(args, strategy, model, tokenizer, labels, pad_token_label_id, mode="test")
output_test_results_file = os.path.join(args['output_dir'], "test_results.txt")
output_test_predictions_file = os.path.join(args['output_dir'], "test_predictions.txt")
report = metrics.classification_report(y_true, y_pred, digits=4)
with tf.io.gfile.GFile(output_test_results_file, "w") as writer:
report = metrics.classification_report(y_true, y_pred, digits=4)
logging.info("\n" + report)
writer.write(report)
writer.write("\n\nloss = " + str(pred_loss))
with tf.io.gfile.GFile(output_test_predictions_file, "w") as writer:
with tf.io.gfile.GFile(os.path.join(args['data_dir'], "test.txt"), "r") as f:
example_id = 0
for line in f:
if line.startswith("-DOCSTART-") or line == "" or line == "\n":
writer.write(line)
if not y_pred[example_id]:
example_id += 1
elif y_pred[example_id]:
output_line = line.split()[0] + " " + y_pred[example_id].pop(0) + "\n"
writer.write(output_line)
else:
logging.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
if __name__ == "__main__":
flags.mark_flag_as_required("data_dir")
flags.mark_flag_as_required("output_dir")
flags.mark_flag_as_required("model_name_or_path")
flags.mark_flag_as_required("model_type")
app.run(main)

View File

@ -164,6 +164,7 @@ if is_tf_available():
from .modeling_tf_distilbert import (TFDistilBertPreTrainedModel, TFDistilBertMainLayer,
TFDistilBertModel, TFDistilBertForMaskedLM,
TFDistilBertForSequenceClassification,
TFDistilBertForTokenClassification,
TFDistilBertForQuestionAnswering,
TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
@ -174,6 +175,8 @@ if is_tf_available():
from .modeling_tf_albert import (TFAlbertPreTrainedModel, TFAlbertModel, TFAlbertForMaskedLM,
TFAlbertForSequenceClassification,
TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
# Optimization
from .optimization_tf import (WarmUp, create_optimizer, AdamWeightDecay, GradientAccumulator)
# TF 2.0 <=> PyTorch conversion utilities
from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,

View File

@ -704,6 +704,53 @@ class TFDistilBertForSequenceClassification(TFDistilBertPreTrainedModel):
return outputs # logits, (hidden_states), (attentions)
@add_start_docstrings("""DistilBert Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)
class TFDistilBertForTokenClassification(TFDistilBertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
Classification scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import DistilBertTokenizer, TFDistilBertForTokenClassification
tokenizer = DistilBertTokenizer.from_pretrained('bert-base-uncased')
model = TFDistilBertForTokenClassification.from_pretrained('bert-base-uncased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
scores = outputs[0]
"""
def __init__(self, config, *inputs, **kwargs):
super(TFDistilBertForTokenClassification, self).__init__(config, *inputs, **kwargs)
self.num_labels = config.num_labels
self.distilbert = TFDistilBertMainLayer(config, name='distilbert')
self.dropout = tf.keras.layers.Dropout(config.dropout)
self.classifier = tf.keras.layers.Dense(config.num_labels,
kernel_initializer=get_initializer(config.initializer_range),
name='classifier')
def call(self, inputs, **kwargs):
outputs = self.distilbert(inputs, **kwargs)
sequence_output = outputs[0]
sequence_output = self.dropout(sequence_output, training=kwargs.get('training', False))
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
return outputs # scores, (hidden_states), (attentions)
@add_start_docstrings("""DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
DISTILBERT_START_DOCSTRING, DISTILBERT_INPUTS_DOCSTRING)

View File

@ -0,0 +1,254 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
import tensorflow as tf
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
"""Applys a warmup schedule on a given learning rate decay schedule."""
def __init__(
self,
initial_learning_rate,
decay_schedule_fn,
warmup_steps,
power=1.0,
name=None):
super(WarmUp, self).__init__()
self.initial_learning_rate = initial_learning_rate
self.warmup_steps = warmup_steps
self.power = power
self.decay_schedule_fn = decay_schedule_fn
self.name = name
def __call__(self, step):
with tf.name_scope(self.name or 'WarmUp') as name:
# Implements polynomial warmup. i.e., if global_step < warmup_steps, the
# learning rate will be `global_step/num_warmup_steps * init_lr`.
global_step_float = tf.cast(step, tf.float32)
warmup_steps_float = tf.cast(self.warmup_steps, tf.float32)
warmup_percent_done = global_step_float / warmup_steps_float
warmup_learning_rate = (
self.initial_learning_rate *
tf.math.pow(warmup_percent_done, self.power))
return tf.cond(global_step_float < warmup_steps_float,
lambda: warmup_learning_rate,
lambda: self.decay_schedule_fn(step),
name=name)
def get_config(self):
return {
'initial_learning_rate': self.initial_learning_rate,
'decay_schedule_fn': self.decay_schedule_fn,
'warmup_steps': self.warmup_steps,
'power': self.power,
'name': self.name
}
def create_optimizer(init_lr, num_train_steps, num_warmup_steps):
"""Creates an optimizer with learning rate schedule."""
# Implements linear decay of the learning rate.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
initial_learning_rate=init_lr,
decay_steps=num_train_steps,
end_learning_rate=0.0)
if num_warmup_steps:
learning_rate_fn = WarmUp(initial_learning_rate=init_lr,
decay_schedule_fn=learning_rate_fn,
warmup_steps=num_warmup_steps)
optimizer = AdamWeightDecay(
learning_rate=learning_rate_fn,
weight_decay_rate=0.01,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-6,
exclude_from_weight_decay=['layer_norm', 'bias'])
return optimizer
class AdamWeightDecay(tf.keras.optimizers.Adam):
"""Adam enables L2 weight decay and clip_by_global_norm on gradients.
Just adding the square of the weights to the loss function is *not* the
correct way of using L2 regularization/weight decay with Adam, since that will
interact with the m and v parameters in strange ways.
Instead we want ot decay the weights in a manner that doesn't interact with
the m/v parameters. This is equivalent to adding the square of the weights to
the loss with plain (non-momentum) SGD.
"""
def __init__(self,
learning_rate=0.001,
beta_1=0.9,
beta_2=0.999,
epsilon=1e-7,
amsgrad=False,
weight_decay_rate=0.0,
include_in_weight_decay=None,
exclude_from_weight_decay=None,
name='AdamWeightDecay',
**kwargs):
super(AdamWeightDecay, self).__init__(
learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
self.weight_decay_rate = weight_decay_rate
self._include_in_weight_decay = include_in_weight_decay
self._exclude_from_weight_decay = exclude_from_weight_decay
@classmethod
def from_config(cls, config):
"""Creates an optimizer from its config with WarmUp custom object."""
custom_objects = {'WarmUp': WarmUp}
return super(AdamWeightDecay, cls).from_config(
config, custom_objects=custom_objects)
def _prepare_local(self, var_device, var_dtype, apply_state):
super(AdamWeightDecay, self)._prepare_local(var_device, var_dtype,
apply_state)
apply_state['weight_decay_rate'] = tf.constant(
self.weight_decay_rate, name='adam_weight_decay_rate')
def _decay_weights_op(self, var, learning_rate, apply_state):
do_decay = self._do_use_weight_decay(var.name)
if do_decay:
return var.assign_sub(
learning_rate * var *
apply_state['weight_decay_rate'],
use_locking=self._use_locking)
return tf.no_op()
def apply_gradients(self, grads_and_vars, clip_norm, name=None):
grads, tvars = list(zip(*grads_and_vars))
(grads, _) = tf.clip_by_global_norm(grads, clip_norm=clip_norm)
return super(AdamWeightDecay, self).apply_gradients(zip(grads, tvars))
def _get_lr(self, var_device, var_dtype, apply_state):
"""Retrieves the learning rate with the given state."""
if apply_state is None:
return self._decayed_lr_t[var_dtype], {}
apply_state = apply_state or {}
coefficients = apply_state.get((var_device, var_dtype))
if coefficients is None:
coefficients = self._fallback_apply_state(var_device, var_dtype)
apply_state[(var_device, var_dtype)] = coefficients
return coefficients['lr_t'], dict(apply_state=apply_state)
def _resource_apply_dense(self, grad, var, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay, self)._resource_apply_dense(
grad, var, **kwargs)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
lr_t, kwargs = self._get_lr(var.device, var.dtype.base_dtype, apply_state)
decay = self._decay_weights_op(var, lr_t, apply_state)
with tf.control_dependencies([decay]):
return super(AdamWeightDecay, self)._resource_apply_sparse(
grad, var, indices, **kwargs)
def get_config(self):
config = super(AdamWeightDecay, self).get_config()
config.update({
'weight_decay_rate': self.weight_decay_rate,
})
return config
def _do_use_weight_decay(self, param_name):
"""Whether to use L2 weight decay for `param_name`."""
if self.weight_decay_rate == 0:
return False
if self._include_in_weight_decay:
for r in self._include_in_weight_decay:
if re.search(r, param_name) is not None:
return True
if self._exclude_from_weight_decay:
for r in self._exclude_from_weight_decay:
if re.search(r, param_name) is not None:
return False
return True
## Inspired from https://github.com/OpenNMT/OpenNMT-tf/blob/master/opennmt/optimizers/utils.py
class GradientAccumulator(object):
"""Distribution strategies-aware gradient accumulation utility."""
def __init__(self):
"""Initializes the accumulator."""
self._gradients = []
self._accum_steps = tf.Variable(
initial_value=0,
dtype=tf.int64,
trainable=False,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
@property
def step(self):
"""Number of accumulated steps."""
return self._accum_steps.value()
@property
def gradients(self):
"""The accumulated gradients."""
return list(gradient.value() if gradient is not None else gradient for gradient in self._get_replica_gradients())
def __call__(self, gradients):
"""Accumulates :obj:`gradients`."""
if not self._gradients:
self._gradients.extend([tf.Variable(tf.zeros_like(gradient), trainable=False) if gradient is not None else gradient for gradient in gradients])
if len(gradients) != len(self._gradients):
raise ValueError("Expected %s gradients, but got %d" % (len(self._gradients), len(gradients)))
for accum_gradient, gradient in zip(self._get_replica_gradients(), gradients):
if accum_gradient is not None:
accum_gradient.assign_add(gradient)
self._accum_steps.assign_add(1)
def reset(self):
"""Resets the accumulated gradients."""
if self._gradients:
self._accum_steps.assign(0)
for gradient in self._get_replica_gradients():
if gradient is not None:
gradient.assign(tf.zeros_like(gradient))
def _get_replica_gradients(self):
if tf.distribute.has_strategy():
# In a replica context, we want to accumulate gradients on each replica
# without synchronization, so we directly assign the value of the
# current replica.
replica_context = tf.distribute.get_replica_context()
if replica_context is None or tf.distribute.get_strategy().num_replicas_in_sync == 1:
return self._gradients
return (gradient.device_map.select_for_current_replica(gradient.values, replica_context) for gradient in self._gradients)
else:
return self._gradients

View File

@ -0,0 +1,89 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import unittest
import pytest
from transformers import is_tf_available
if is_tf_available():
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from transformers import (create_optimizer, GradientAccumulator)
else:
pytestmark = pytest.mark.skip("Require TensorFlow")
class OptimizationFTest(unittest.TestCase):
def assertListAlmostEqual(self, list1, list2, tol):
self.assertEqual(len(list1), len(list2))
for a, b in zip(list1, list2):
self.assertAlmostEqual(a, b, delta=tol)
def testGradientAccumulator(self):
accumulator = GradientAccumulator()
accumulator([tf.constant([1.0, 2.0])])
accumulator([tf.constant([-2.0, 1.0])])
accumulator([tf.constant([-1.0, 2.0])])
with self.assertRaises(ValueError):
accumulator([tf.constant([1.0, 1.0]), tf.constant([2.0, 2.0])])
self.assertEqual(accumulator.step, 3)
self.assertEqual(len(accumulator.gradients), 1)
self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [-2.0, 5.0], tol=1e-2)
accumulator.reset()
self.assertEqual(accumulator.step, 0)
self.assertListAlmostEqual(accumulator.gradients[0].numpy().tolist(), [0.0, 0.0], tol=1e-2)
def testGradientAccumulatorDistributionStrategy(self):
context._context = None
ops.enable_eager_execution_internal()
physical_devices = tf.config.experimental.list_physical_devices("CPU")
tf.config.experimental.set_virtual_device_configuration(
physical_devices[0],
[tf.config.experimental.VirtualDeviceConfiguration(),
tf.config.experimental.VirtualDeviceConfiguration()])
devices = tf.config.experimental.list_logical_devices(device_type="CPU")
strategy = tf.distribute.MirroredStrategy(devices=[device.name for device in devices])
with strategy.scope():
accumulator = GradientAccumulator()
variable = tf.Variable([4.0, 3.0])
optimizer = create_optimizer(5e-5, 10, 5)
gradient_placeholder = tf.Variable([0.0, 0.0], trainable=False)
def accumulate_on_replica(gradient):
accumulator([gradient])
def apply_on_replica():
optimizer.apply_gradients(list(zip(accumulator.gradients, [variable])), 1.0)
@tf.function
def accumulate(grad1, grad2):
with strategy.scope():
gradient_placeholder.values[0].assign(grad1)
gradient_placeholder.values[1].assign(grad2)
strategy.experimental_run_v2(accumulate_on_replica, args=(gradient_placeholder,))
@tf.function
def apply_grad():
with strategy.scope():
strategy.experimental_run_v2(apply_on_replica)
accumulate([1.0, 2.0], [-1.0, 1.0])
accumulate([3.0, -1.0], [-1.0, -1.0])
accumulate([-2.0, 2.0], [3.0, -2.0])
self.assertEqual(accumulator.step, 3)
self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [2.0, 3.0], tol=1e-2)
self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [1.0, -2.0], tol=1e-2)
apply_grad()
self.assertListAlmostEqual(variable.value().numpy().tolist(), [4.0, 3.0], tol=1e-2)
accumulator.reset()
self.assertEqual(accumulator.step, 0)
self.assertListAlmostEqual(accumulator._gradients[0].values[0].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
self.assertListAlmostEqual(accumulator._gradients[0].values[1].value().numpy().tolist(), [0.0, 0.0], tol=1e-2)
if __name__ == "__main__":
unittest.main()