[Model card] Bert2Bert

Add Rouge2 results
This commit is contained in:
Patrick von Platen 2020-07-17 18:22:05 +02:00 committed by GitHub
parent 9d37c56bab
commit 12f14710ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 43 additions and 8 deletions

View File

@ -125,12 +125,10 @@ def compute_metrics(pred):
labels_ids = pred.label_ids
pred_ids = pred.predictions
pred_str = tokenizer.batch_decode(pred_ids, clean_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, clean_special_tokens=True)
pred_str = [pred.split("[CLS]")[-1].split("[SEP]")[0] for pred in pred_str]
label_str = [label.split("[CLS]")[-1].split("[SEP]")[0] for label in label_str]
# all unnecessary tokens are removed
pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid
return {
@ -189,6 +187,43 @@ trainer = Trainer(
trainer.train()
```
## Results
## Evaluation
TODO
The following script evaluates the model on the test set of
CNN/Daily Mail.
```python
#!/usr/bin/env python3
import nlp
from transformers import BertTokenizer, EncoderDecoderModel
tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
model.to("cuda")
test_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="test")
batch_size = 128
# map data correctly
def generate_summary(batch):
# Tokenizer will automatically set [BOS] <text> [EOS]
# cut off at BERT max length 512
inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
input_ids = inputs.input_ids.to("cuda")
attention_mask = inputs.attention_mask.to("cuda")
outputs = model.generate(input_ids, attention_mask=attention_mask)
# all special tokens including will be removed
output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
batch["pred"] = output_str
return batch
results = test_dataset.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])
# load rouge for validation
rouge = nlp.load_metric("rouge")
pred_str = results["pred"]
label_str = results["highlights"]
rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid
print(rouge_output)
```
The obtained results should be:
| - | Rouge2 - mid -precision | Rouge2 - mid - recall | Rouge2 - mid - fmeasure |
|----------|:-------------:|:------:|:------:|
| **CNN/Daily Mail** | 14.12 | 14.37 | **13.8** |