diff --git a/model_cards/patrickvonplaten/bert2bert-cnn_dailymail-fp16/README.md b/model_cards/patrickvonplaten/bert2bert-cnn_dailymail-fp16/README.md index a873aa0710..8144459fa4 100644 --- a/model_cards/patrickvonplaten/bert2bert-cnn_dailymail-fp16/README.md +++ b/model_cards/patrickvonplaten/bert2bert-cnn_dailymail-fp16/README.md @@ -125,12 +125,10 @@ def compute_metrics(pred): labels_ids = pred.label_ids pred_ids = pred.predictions - pred_str = tokenizer.batch_decode(pred_ids, clean_special_tokens=True) - label_str = tokenizer.batch_decode(labels_ids, clean_special_tokens=True) - - pred_str = [pred.split("[CLS]")[-1].split("[SEP]")[0] for pred in pred_str] - label_str = [label.split("[CLS]")[-1].split("[SEP]")[0] for label in label_str] - + # all unnecessary tokens are removed + pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) + label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True) + rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid return { @@ -189,6 +187,43 @@ trainer = Trainer( trainer.train() ``` -## Results +## Evaluation -TODO +The following script evaluates the model on the test set of +CNN/Daily Mail. + +```python +#!/usr/bin/env python3 +import nlp +from transformers import BertTokenizer, EncoderDecoderModel +tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") +model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16") +model.to("cuda") +test_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="test") +batch_size = 128 +# map data correctly +def generate_summary(batch): + # Tokenizer will automatically set [BOS] [EOS] + # cut off at BERT max length 512 + inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt") + input_ids = inputs.input_ids.to("cuda") + attention_mask = inputs.attention_mask.to("cuda") + outputs = model.generate(input_ids, attention_mask=attention_mask) + # all special tokens including will be removed + output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True) + batch["pred"] = output_str + return batch +results = test_dataset.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"]) +# load rouge for validation +rouge = nlp.load_metric("rouge") +pred_str = results["pred"] +label_str = results["highlights"] +rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid +print(rouge_output) +``` + +The obtained results should be: + +| - | Rouge2 - mid -precision | Rouge2 - mid - recall | Rouge2 - mid - fmeasure | +|----------|:-------------:|:------:|:------:| +| **CNN/Daily Mail** | 14.12 | 14.37 | **13.8** |