Adding the prepare_seq2seq_batch function to ProphetNet (#8515)

* Simply insert T5Tokenizer's prepare_seq2seq_batch

* Update/Add some 'import'

* fix RunTimeError caused by '.view'

* Moves .view related error avoidance from seq2seq_trainer to inside prophetnet

* Update test_tokenization_prophetnet.py

* Format the test code with black

* Re-format the test code

* Update test_tokenization_prophetnet.py

* Add importing require_torch in the test code

* Add importing BatchEncoding in the test code

* Re-format the test code on Colab
This commit is contained in:
Yusuke Mori 2020-11-16 22:18:25 +09:00 committed by GitHub
parent 931b10978e
commit 04d8136bde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 71 additions and 2 deletions

View File

@ -1766,6 +1766,10 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
logits = predict_logits[:, 0]
logits_ngram = predict_logits[:, 1:] if self.config.ngram > 1 else None
# To use .view in loss computation, make sure that logits is contiguous.
if not logits.is_contiguous():
logits = logits.contiguous()
loss = None
if labels is not None:
loss = self._compute_loss(predict_logits, labels)

View File

@ -17,8 +17,10 @@ import collections
import os
from typing import List, Optional, Tuple
from .file_utils import add_start_docstrings
from .tokenization_bert import BasicTokenizer, WordpieceTokenizer
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils import BatchEncoding, PreTrainedTokenizer
from .tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING
from .utils import logging
@ -286,3 +288,43 @@ class ProphetNetTokenizer(PreTrainedTokenizer):
return token_ids_0 + [self.sep_token_id]
sep = [self.sep_token_id]
return token_ids_0 + sep + token_ids_1 + sep
@add_start_docstrings(PREPARE_SEQ2SEQ_BATCH_DOCSTRING)
def prepare_seq2seq_batch(
self,
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
padding: str = "longest",
return_tensors: str = None,
truncation: bool = True,
**kwargs,
) -> BatchEncoding:
if max_length is None:
max_length = self.max_len
model_inputs = self(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
padding=padding,
truncation=truncation,
**kwargs,
)
if tgt_texts is None:
return model_inputs
# Process tgt_texts
if max_target_length is None:
max_target_length = max_length
labels_and_decoder_mask = self(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
padding=padding,
max_length=max_target_length,
truncation=truncation,
**kwargs,
)
model_inputs["labels"] = labels_and_decoder_mask["input_ids"]
return model_inputs

View File

@ -17,7 +17,8 @@
import os
import unittest
from transformers.testing_utils import slow
from transformers import BatchEncoding
from transformers.testing_utils import require_torch, slow
from transformers.tokenization_bert import (
BasicTokenizer,
WordpieceTokenizer,
@ -150,6 +151,28 @@ class ProphetNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
self.assertListEqual(tokenizer.tokenize("unwantedX running"), ["[UNK]", "runn", "##ing"])
@require_torch
def test_prepare_seq2seq_batch(self):
tokenizer = self.tokenizer_class.from_pretrained("microsoft/prophetnet-large-uncased")
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [1037, 2146, 20423, 2005, 7680, 7849, 3989, 1012, 102]
batch = tokenizer.prepare_seq2seq_batch(
src_text,
tgt_texts=tgt_text,
return_tensors="pt",
)
self.assertIsInstance(batch, BatchEncoding)
result = list(batch.input_ids.numpy()[0])
self.assertListEqual(expected_src_tokens, result)
self.assertEqual((2, 9), batch.input_ids.shape)
self.assertEqual((2, 9), batch.attention_mask.shape)
def test_is_whitespace(self):
self.assertTrue(_is_whitespace(" "))
self.assertTrue(_is_whitespace("\t"))