parent 9d37c56bab
commit 12f14710ce
@@ -125,12 +125,10 @@ def compute_metrics(pred):
     labels_ids = pred.label_ids
     pred_ids = pred.predictions
 
-    pred_str = tokenizer.batch_decode(pred_ids, clean_special_tokens=True)
-    label_str = tokenizer.batch_decode(labels_ids, clean_special_tokens=True)
-
-    pred_str = [pred.split("[CLS]")[-1].split("[SEP]")[0] for pred in pred_str]
-    label_str = [label.split("[CLS]")[-1].split("[SEP]")[0] for label in label_str]
-
+    # all unnecessary tokens are removed
+    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
+
     rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid
 
     return {
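For context, a minimal sketch of how the full `compute_metrics` function could read after this change. The names `tokenizer` and `rouge` are assumed to be the objects set up earlier in the README, and the keys of the returned dict are illustrative only, since the hunk cuts off at `return {`:

```python
# Sketch only: assumes `tokenizer` and `rouge` (e.g. rouge = nlp.load_metric("rouge"))
# are already defined as in the rest of the README.
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # decode predictions and labels; all unnecessary special tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    # illustrative keys: the actual return statement is truncated in this hunk
    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }
```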
@@ -189,6 +187,43 @@ trainer = Trainer(
 trainer.train()
 ```
 
-## Results
+## Evaluation
 
-TODO
+The following script evaluates the model on the test set of
+CNN/Daily Mail.
+
+```python
+#!/usr/bin/env python3
+import nlp
+from transformers import BertTokenizer, EncoderDecoderModel
+tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
+model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
+model.to("cuda")
+test_dataset = nlp.load_dataset("cnn_dailymail", "3.0.0", split="test")
+batch_size = 128
+# map data correctly
+def generate_summary(batch):
+    # Tokenizer will automatically set [CLS] <text> [SEP]
+    # cut off at BERT max length 512
+    inputs = tokenizer(batch["article"], padding="max_length", truncation=True, max_length=512, return_tensors="pt")
+    input_ids = inputs.input_ids.to("cuda")
+    attention_mask = inputs.attention_mask.to("cuda")
+    outputs = model.generate(input_ids, attention_mask=attention_mask)
+    # all special tokens will be removed
+    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
+    batch["pred"] = output_str
+    return batch
+results = test_dataset.map(generate_summary, batched=True, batch_size=batch_size, remove_columns=["article"])
+# load rouge for validation
+rouge = nlp.load_metric("rouge")
+pred_str = results["pred"]
+label_str = results["highlights"]
+rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid
+print(rouge_output)
+```
+
+The obtained results should be:
+
+| - | Rouge2 - mid - precision | Rouge2 - mid - recall | Rouge2 - mid - fmeasure |
+|----------|:-------------:|:------:|:------:|
+| **CNN/Daily Mail** | 14.12 | 14.37 | **13.8** |
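A note for anyone re-running this evaluation today: the `nlp` package used in the script was later renamed to `datasets`, and its metrics were split out into the separate `evaluate` package. A minimal sketch of the equivalent setup under the newer names (the rest of the script stays the same; the exact ROUGE output format differs slightly between versions):

```python
# Sketch, assuming current `datasets` / `evaluate` releases are installed.
import datasets
import evaluate

# same dataset as above, loaded through the renamed library
test_dataset = datasets.load_dataset("cnn_dailymail", "3.0.0", split="test")

# ROUGE now lives in `evaluate`; compute() returns plain floats per ROUGE type
rouge = evaluate.load("rouge")
scores = rouge.compute(predictions=["a summary"], references=["a reference summary"])
print(scores["rouge2"])
```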