[Seq2Seq] Correct import in Seq2Seq Trainer (#8254)

This commit is contained in:
Patrick von Platen 2020-11-03 13:56:41 +01:00 committed by GitHub
parent 504ff7bb12
commit 9f1747f999
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 4 deletions

View File

@ -62,10 +62,7 @@ class Seq2SeqTrainer(Trainer):
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_token_id)
else:
# dynamically import label_smoothed_nll_loss
try:
from .utils import label_smoothed_nll_loss
except ImportError:
from utils import label_smoothed_nll_loss
from utils import label_smoothed_nll_loss
self.loss_fn = label_smoothed_nll_loss