[examples (seq2seq)] fix preparing decoder_input_ids for T5 (#5994)
This commit is contained in:
parent
3deffc1d67
commit
d1d15d6f2d
|
@ -14,7 +14,7 @@ import torch
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from transformers import MBartTokenizer, get_linear_schedule_with_warmup
|
||||
from transformers import MBartTokenizer, T5ForConditionalGeneration, get_linear_schedule_with_warmup
|
||||
|
||||
|
||||
try:
|
||||
|
@ -131,8 +131,14 @@ class SummarizationModule(BaseTransformer):
|
|||
def _step(self, batch: dict) -> Tuple:
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, target_ids = batch["input_ids"], batch["attention_mask"], batch["decoder_input_ids"]
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
|
||||
lm_labels = target_ids[:, 1:].clone() # why clone?
|
||||
|
||||
if isinstance(self.model, T5ForConditionalGeneration):
|
||||
decoder_input_ids = self.model._shift_right(target_ids)
|
||||
lm_labels = target_ids
|
||||
else:
|
||||
decoder_input_ids = target_ids[:, :-1].contiguous() # Why this line?
|
||||
lm_labels = target_ids[:, 1:].clone() # why clone?
|
||||
|
||||
outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
|
||||
|
||||
if self.hparams.label_smoothing == 0:
|
||||
|
|
Loading…
Reference in New Issue