[examples (seq2seq)] fix preparing decoder_input_ids for T5 (#5994)

Suraj Patil 2020-07-27 19:40:43 +05:30 committed by GitHub
parent 3deffc1d67
commit d1d15d6f2d
1 changed file with 9 additions and 3 deletions


@@ -14,7 +14,7 @@ import torch
 from torch.utils.data import DataLoader
 
 from lightning_base import BaseTransformer, add_generic_args, generic_train
-from transformers import MBartTokenizer, get_linear_schedule_with_warmup
+from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
 
 
 try:
@@ -131,8 +131,14 @@ class SummarizationModule(BaseTransformer):
     def _step(self, batch: dict) -> Tuple:
         pad_token_id = self.tokenizer.pad_token_id
         source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
-        decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
-        lm_labels = target_ids[:, 1:].clone()  # why clone?
+        if isinstance(self.model, T5ForConditionalGeneration):
+            decoder_input_ids = self.model._shift_right(target_ids)
+            lm_labels = target_ids
+        else:
+            decoder_input_ids = target_ids[:, :-1].contiguous()  # Why this line?
+            lm_labels = target_ids[:, 1:].clone()  # why clone?
         outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
         if self.hparams.label_smoothing == 0:
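
Context for the change: the T5 tokenizer does not prepend a decoder start token to the target sequence, so the generic target_ids[:, :-1] / target_ids[:, 1:] slicing would drop the first real target token rather than a start token. T5ForConditionalGeneration._shift_right instead shifts the targets one position to the right and inserts the decoder start token (the pad token for T5), while the unshifted targets are kept as lm_labels. Below is a minimal sketch of that behaviour; the helper name shift_right and the token ids are illustrative, not the library's exact code.

import torch

def shift_right(target_ids: torch.Tensor, decoder_start_token_id: int, pad_token_id: int) -> torch.Tensor:
    # Shift every target token one position to the right and put the decoder
    # start token in the first slot, mirroring what
    # T5ForConditionalGeneration._shift_right does with the labels.
    shifted = target_ids.new_zeros(target_ids.shape)
    shifted[:, 1:] = target_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # Any -100 label padding must be replaced with a real pad token id before
    # the ids are fed to the decoder.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

# For T5 the decoder start token is the pad token (id 0).
target_ids = torch.tensor([[37, 363, 19, 1]])  # made-up ids ending in </s>
decoder_input_ids = shift_right(target_ids, decoder_start_token_id=0, pad_token_id=0)
# decoder_input_ids -> tensor([[0, 37, 363, 19]]); lm_labels stay equal to target_ids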