NER: Add new WNUT’17 example (#4681)

* ner: add preprocessing script for examples that splits longer sentences

* ner: example shell scripts use local preprocessing now

* ner: add new example section for WNUT’17 NER task. Remove old English CoNLL-03 results

* ner: satisfy black and isort
Stefan Schweter 2020-06-05 01:13:17 +02:00 committed by GitHub
parent 0e1869cc28
commit 2a4b9e09c0
4 changed files with 167 additions and 31 deletions

View File

@@ -2,10 +2,17 @@
Based on the scripts [`run_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_ner.py) for PyTorch and
[`run_tf_ner.py`](https://github.com/huggingface/transformers/blob/master/examples/token-classification/run_tf_ner.py) for TensorFlow 2.
This example fine-tunes BERT Multilingual on GermEval 2014 (German NER).
The following examples are covered in this section:
* NER on the GermEval 2014 (German NER) dataset
* Emerging and Rare Entities task: WNUT17 (English NER) dataset
Details and results for the fine-tuning provided by @stefan-it.
### Data (Download and pre-processing steps)
### GermEval 2014 (German NER) dataset
#### Data (Download and pre-processing steps)
Data can be obtained from the [GermEval 2014](https://sites.google.com/site/germeval2014ner/data) shared task page.
@@ -20,11 +27,10 @@ curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attr
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp
```
The GermEval 2014 dataset contains some strange "control character" tokens like `'\x96', '\u200e', '\x95', '\xad' or '\x80'`. One problem with these tokens is that `BertTokenizer` returns an empty token for them, resulting in misaligned `InputExample`s. I wrote a script that a) filters these tokens and b) splits longer sentences into smaller ones (once the max. subtoken length is reached).
The GermEval 2014 dataset contains some strange "control character" tokens like `'\x96', '\u200e', '\x95', '\xad' or '\x80'`.
One problem with these tokens is that `BertTokenizer` returns an empty token for them, resulting in misaligned `InputExample`s.
The `preprocess.py` script located in the `scripts` folder a) filters these tokens and b) splits longer sentences into smaller ones (once the max. subtoken length is reached).
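To see why these tokens are a problem, here is a minimal illustration (not part of the original example; it assumes the `transformers` library is installed and that the multilingual BERT vocabulary can be downloaded):

```python
# Minimal sketch: why control-character tokens break the token/label alignment.
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")

# A regular token is split into at least one subtoken.
print(tokenizer.tokenize("Schartau"))  # non-empty list of subword pieces

# A control-character token yields no subtokens at all, so its label
# would have nothing to attach to in the resulting InputExample.
print(tokenizer.tokenize("\x96"))      # []
```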
```bash
wget "https://raw.githubusercontent.com/stefan-it/fine-tuned-berts-seq/master/scripts/preprocess.py"
```
Let's define some variables that we need for further pre-processing steps and training the model:
```bash
@@ -35,9 +41,9 @@ export BERT_MODEL=bert-base-multilingual-cased
Run the pre-processing script on training, dev and test datasets:
```bash
python3 preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
python3 preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
python3 preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
python3 scripts/preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
python3 scripts/preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
python3 scripts/preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
```
The GermEval 2014 dataset has many more labels than the CoNLL-2002/2003 datasets, so its own set of labels must be used:
@@ -46,7 +52,7 @@ The GermEval 2014 dataset has much more labels than CoNLL-2002/2003 datasets, so
cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt
```
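If you prefer doing this step in Python, the following is a small equivalent sketch of the shell pipeline above (it assumes the preprocessed `train.txt`, `dev.txt` and `test.txt` files are in the current directory):

```python
# Collect the unique labels (second column) from the preprocessed files,
# mirroring the `cut | grep | sort | uniq` pipeline above.
labels = set()
for split in ("train.txt", "dev.txt", "test.txt"):
    with open(split, encoding="utf-8") as f:
        for line in f:
            parts = line.rstrip("\n").split(" ")
            if len(parts) > 1 and parts[1]:
                labels.add(parts[1])

with open("labels.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(sorted(labels)) + "\n")
```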
### Prepare the run
#### Prepare the run
Additional environment variables must be set:
@@ -58,7 +64,7 @@ export SAVE_STEPS=750
export SEED=1
```
### Run the PyTorch version
#### Run the PyTorch version
To start training, just run:
@@ -79,7 +85,7 @@ python3 run_ner.py --data_dir ./ \
If your GPU supports half-precision training, just add the `--fp16` flag. After training, the model will be evaluated on both the development and test datasets.
### JSON-based configuration file
#### JSON-based configuration file
Instead of passing all parameters via command-line arguments, the `run_ner.py` script also supports reading parameters from a JSON-based configuration file:
@@ -124,16 +130,6 @@ On the test dataset the following results could be achieved:
10/04/2019 00:42:42 - INFO - __main__ - recall = 0.8624150210424085
```
#### Comparing BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased)
Here is a small comparison between BERT (large, cased), RoBERTa (large, cased) and DistilBERT (base, uncased) with the same hyperparameters as specified in the [example documentation](https://huggingface.co/transformers/examples.html#named-entity-recognition) (one run):
| Model | F-Score Dev | F-Score Test
| --------------------------------- | ------- | --------
| `bert-large-cased` | 95.59 | 91.70
| `roberta-large` | 95.96 | 91.87
| `distilbert-base-uncased` | 94.34 | 90.32
#### Run PyTorch version using PyTorch-Lightning
Run `bash run_pl.sh` from the `ner` directory. This also installs `pytorch-lightning` and the dependencies listed in `examples/requirements.txt`. It is a shell pipeline that automatically downloads and pre-processes the data and trains the model in the `germeval-model` directory. Logs are saved in the `lightning_logs` directory.
@@ -141,7 +137,7 @@ Run `bash run_pl.sh` from the `ner` directory. This would also install `pytorch-
Pass the `--n_gpu` flag to change the number of GPUs; the default is 1. At the end, the expected results are: `TEST RESULTS {'val_loss': tensor(0.0707), 'precision': 0.852427800698191, 'recall': 0.869537067011978, 'f1': 0.8608974358974358}`
### Run the TensorFlow 2 version
#### Run the TensorFlow 2 version
To start training, just run:
@@ -205,3 +201,102 @@ On the test dataset the following results could be achieved:
micro avg 0.8722 0.8774 0.8748 13869
macro avg 0.8712 0.8774 0.8740 13869
```
### Emerging and Rare Entities task: WNUT17 (English NER) dataset
Description of the WNUT17 task from the [shared task website](http://noisy-text.github.io/2017/index.html):
> The WNUT17 shared task focuses on identifying unusual, previously-unseen entities in the context of emerging discussions.
> Named entities form the basis of many modern approaches to other tasks (like event clustering and summarization), but recall on
> them is a real problem in noisy text - even among annotators. This drop tends to be due to novel entities and surface forms.
Six labels are available in the dataset. An overview can be found on this [page](http://noisy-text.github.io/2017/files/).
#### Data (Download and pre-processing steps)
The dataset can be downloaded from the [official GitHub](https://github.com/leondz/emerging_entities_17) repository.
The following commands show how to prepare the dataset for fine-tuning:
```bash
mkdir -p data_wnut_17
curl -L 'https://github.com/leondz/emerging_entities_17/raw/master/wnut17train.conll' | tr '\t' ' ' > data_wnut_17/train.txt.tmp
curl -L 'https://github.com/leondz/emerging_entities_17/raw/master/emerging.dev.conll' | tr '\t' ' ' > data_wnut_17/dev.txt.tmp
curl -L 'https://raw.githubusercontent.com/leondz/emerging_entities_17/master/emerging.test.annotated' | tr '\t' ' ' > data_wnut_17/test.txt.tmp
```
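Before continuing, it can be useful to check that the download and the tab-to-space conversion worked. Here is a small, optional sanity-check sketch; it only relies on the three `.tmp` files created by the commands above:

```python
# Count sentences (separated by blank lines) and distinct labels per WNUT'17 split.
from collections import Counter

for split in ("train", "dev", "test"):
    path = f"data_wnut_17/{split}.txt.tmp"
    sentences, labels = 0, Counter()
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                sentences += 1
                continue
            labels[line.split(" ")[-1]] += 1
    print(split, sentences, "sentences,", len(labels), "distinct labels")
```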
Let's define some variables that we need for further pre-processing steps:
```bash
export MAX_LENGTH=128
export BERT_MODEL=bert-large-cased
```
Here we use the English BERT large model for fine-tuning.
The `preprocess.py` script splits longer sentences into smaller ones (once the max. subtoken length is reached):
```bash
python3 scripts/preprocess.py data_wnut_17/train.txt.tmp $BERT_MODEL $MAX_LENGTH > data_wnut_17/train.txt
python3 scripts/preprocess.py data_wnut_17/dev.txt.tmp $BERT_MODEL $MAX_LENGTH > data_wnut_17/dev.txt
python3 scripts/preprocess.py data_wnut_17/test.txt.tmp $BERT_MODEL $MAX_LENGTH > data_wnut_17/test.txt
```
In the last pre-processing step, the `labels.txt` file needs to be generated. This file contains all available labels:
```bash
cat data_wnut_17/train.txt data_wnut_17/dev.txt data_wnut_17/test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > data_wnut_17/labels.txt
```
#### Run the PyTorch version
Fine-tuning with the PyTorch version can be started using the `run_ner.py` script. In this example we use a JSON-based configuration file.
This configuration file looks like:
```json
{
"data_dir": "./data_wnut_17",
"labels": "./data_wnut_17/labels.txt",
"model_name_or_path": "bert-large-cased",
"output_dir": "wnut-17-model-1",
"max_seq_length": 128,
"num_train_epochs": 3,
"per_device_train_batch_size": 32,
"save_steps": 425,
"seed": 1,
"do_train": true,
"do_eval": true,
"do_predict": true,
"fp16": false
}
```
If your GPU supports half-precision training, please set `fp16` to `true`.
Save this JSON-based configuration under `wnut_17.json`. The fine-tuning can be started with `python3 run_ner.py wnut_17.json`.
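For reference, passing a single `.json` file works because the example script builds its arguments with `HfArgumentParser`, which can read the parameters from such a file instead of command-line flags. The following is a simplified sketch of that pattern; the dataclass shown here is illustrative and not the exact set of arguments used by `run_ner.py`:

```python
# Simplified sketch of the HfArgumentParser JSON pattern used by the example scripts.
import sys
from dataclasses import dataclass, field

from transformers import HfArgumentParser, TrainingArguments


@dataclass
class ExampleDataArguments:
    # Illustrative subset of the data/model keys from the JSON file above.
    data_dir: str = field(default=None)
    labels: str = field(default=None)
    model_name_or_path: str = field(default=None)
    max_seq_length: int = field(default=128)


parser = HfArgumentParser((ExampleDataArguments, TrainingArguments))

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # Keys that match TrainingArguments fields (output_dir, num_train_epochs, ...)
    # are routed there; the remaining known keys fill the dataclass above.
    data_args, training_args = parser.parse_json_file(json_file=sys.argv[1])
else:
    data_args, training_args = parser.parse_args_into_dataclasses()

print(data_args.data_dir, training_args.output_dir)
```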
#### Evaluation
Evaluation on the development dataset outputs the following:
```bash
05/29/2020 23:33:44 - INFO - __main__ - ***** Eval results *****
05/29/2020 23:33:44 - INFO - __main__ - eval_loss = 0.26505235286212275
05/29/2020 23:33:44 - INFO - __main__ - eval_precision = 0.7008264462809918
05/29/2020 23:33:44 - INFO - __main__ - eval_recall = 0.507177033492823
05/29/2020 23:33:44 - INFO - __main__ - eval_f1 = 0.5884802220680084
05/29/2020 23:33:44 - INFO - __main__ - epoch = 3.0
```
On the test dataset the following results could be achieved:
```bash
05/29/2020 23:33:44 - INFO - transformers.trainer - ***** Running Prediction *****
05/29/2020 23:34:02 - INFO - __main__ - eval_loss = 0.30948806500973547
05/29/2020 23:34:02 - INFO - __main__ - eval_precision = 0.5840108401084011
05/29/2020 23:34:02 - INFO - __main__ - eval_recall = 0.3994439295644115
05/29/2020 23:34:02 - INFO - __main__ - eval_f1 = 0.47440836543753434
```
WNUT17 is a very difficult task. Current state-of-the-art results on this dataset can be found [here](http://nlpprogress.com/english/named_entity_recognition.html).
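After fine-tuning has finished, the model saved in `wnut-17-model-1` (the `output_dir` from the configuration above) can be tried out directly. Here is a minimal inference sketch using the `pipeline` API; the sample sentence is made up and the exact fields of the returned entity dictionaries may differ slightly between library versions:

```python
# Load the fine-tuned WNUT'17 model from its output directory and tag a sentence.
from transformers import pipeline

ner = pipeline("ner", model="wnut-17-model-1", tokenizer="wnut-17-model-1")

sentence = "Heard that the new Nintendo Switch sold out in Tokyo within hours"
for entity in ner(sentence):
    print(entity["word"], entity["entity"], round(entity["score"], 3))
```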

View File

@@ -4,12 +4,12 @@ curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attre
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp
wget "https://raw.githubusercontent.com/stefan-it/fine-tuned-berts-seq/master/scripts/preprocess.py"
export MAX_LENGTH=128
export BERT_MODEL=bert-base-multilingual-cased
python3 preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
python3 preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
python3 preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
python3 scripts/preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
python3 scripts/preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
python3 scripts/preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt
export OUTPUT_DIR=germeval-model
export BATCH_SIZE=32

View File

@@ -11,12 +11,12 @@ curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-dev.tsv?attre
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > dev.txt.tmp
curl -L 'https://sites.google.com/site/germeval2014ner/data/NER-de-test.tsv?attredirects=0&d=1' \
| grep -v "^#" | cut -f 2,3 | tr '\t' ' ' > test.txt.tmp
wget "https://raw.githubusercontent.com/stefan-it/fine-tuned-berts-seq/master/scripts/preprocess.py"
export MAX_LENGTH=128
export BERT_MODEL=bert-base-multilingual-cased
python3 preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
python3 preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
python3 preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
python3 scripts/preprocess.py train.txt.tmp $BERT_MODEL $MAX_LENGTH > train.txt
python3 scripts/preprocess.py dev.txt.tmp $BERT_MODEL $MAX_LENGTH > dev.txt
python3 scripts/preprocess.py test.txt.tmp $BERT_MODEL $MAX_LENGTH > test.txt
cat train.txt dev.txt test.txt | cut -d " " -f 2 | grep -v "^$"| sort | uniq > labels.txt
export BATCH_SIZE=32
export NUM_EPOCHS=3

View File

@@ -0,0 +1,41 @@
import sys

from transformers import AutoTokenizer

dataset = sys.argv[1]
model_name_or_path = sys.argv[2]
max_len = int(sys.argv[3])

subword_len_counter = 0

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
# Reserve room for the special tokens (e.g. [CLS]/[SEP]) added at encoding time.
max_len -= tokenizer.num_special_tokens_to_add()

with open(dataset, "rt") as f_p:
    for line in f_p:
        line = line.rstrip()

        if not line:
            print(line)
            subword_len_counter = 0
            continue

        token = line.split()[0]

        current_subwords_len = len(tokenizer.tokenize(token))

        # Token contains strange control characters like \x96 or \x95
        # Just filter out the complete line
        if current_subwords_len == 0:
            continue

        # Start a new sentence once the running subword count would exceed max_len.
        if (subword_len_counter + current_subwords_len) > max_len:
            print("")
            print(line)
            subword_len_counter = current_subwords_len
            continue

        subword_len_counter += current_subwords_len

        print(line)