[s2s] distillation apex breaks return_dict obj (#8631)

* apex breaks return_dict obj

* style
Stas Bekman 2020-11-18 12:51:29 -08:00 committed by GitHub
parent bf3611b2ab
commit d86d57faa3
1 changed file with 10 additions and 8 deletions
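
Why switching from attribute access to indexing fixes the crash: transformers' ModelOutput is a dict subclass that additionally exposes its keys as attributes. Under apex/amp the return_dict output apparently comes back downcast to a plain dict, so student_outputs.logits raises AttributeError while student_outputs["logits"] keeps working on both objects. A minimal, hypothetical sketch of that failure mode (FakeModelOutput below is a stand-in for the real ModelOutput, not code from this commit):

    from collections import OrderedDict

    class FakeModelOutput(OrderedDict):
        """Stand-in for transformers.ModelOutput: a dict whose keys are
        also reachable as attributes."""

        def __getattr__(self, name):
            try:
                return self[name]
            except KeyError:
                raise AttributeError(name)

    rich = FakeModelOutput(logits="tensor")  # what the model normally returns
    plain = dict(rich)  # a plain dict, as after the apex/amp round trip

    assert rich.logits == rich["logits"] == "tensor"  # both styles work here
    assert plain["logits"] == "tensor"  # indexing survives the downcast
    # plain.logits would raise AttributeError: attribute access does not survive

Indexing is therefore the defensive choice whenever a ModelOutput may have passed through code that rebuilds or casts the output, which is what the hunks below apply throughout _step.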

@@ -154,7 +154,7 @@ class SummarizationDistiller(SummarizationModule):
             output_attentions=False,
             use_cache=False,
         )
-        lm_logits = student_outputs.logits
+        lm_logits = student_outputs["logits"]
 
         # Same cross entropy vs. label smoothing logic as finetune.py
         assert lm_logits.shape[-1] == self.model.config.vocab_size
@@ -171,7 +171,9 @@ class SummarizationDistiller(SummarizationModule):
         def zero_tensor():
             return torch.tensor(0.0).type_as(student_lm_loss)
 
-        teacher_enc_outputs = student_outputs.encoder_last_hidden_state  # use this unless self.different_base_models
+        teacher_enc_outputs = student_outputs[
+            "encoder_last_hidden_state"
+        ]  # use this unless self.different_base_models
         hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
         if self.different_encoder:  # compute encoder hidden state loss
             all_teacher_encoder_outputs = self.teacher.get_encoder()(
@@ -180,12 +182,12 @@ class SummarizationDistiller(SummarizationModule):
                 output_hidden_states=self.do_calc_hidden_loss,
             )
             if self.different_base_models:
-                teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state
+                teacher_enc_outputs = all_teacher_encoder_outputs["last_hidden_state"]
             elif self.do_calc_hidden_loss:
                 hid_loss_enc = self.calc_hidden_loss(
                     src_mask,
-                    student_outputs.encoder_hidden_states,
-                    all_teacher_encoder_outputs.hidden_states,
+                    student_outputs["encoder_hidden_states"],
+                    all_teacher_encoder_outputs["hidden_states"],
                     self.e_matches,
                     normalize_hidden=self.hparams.normalize_hidden,
                 )
@@ -199,12 +201,12 @@ class SummarizationDistiller(SummarizationModule):
                 use_cache=False,  # since we are not passing labels, never let this default to True
             )
             dec_mask = decoder_input_ids.ne(pad_token_id)
-            loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits)
+            loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs["logits"])
             if self.do_calc_hidden_loss:  # Intermediate supervision of decoder hidden states
                 hid_loss_dec = self.calc_hidden_loss(
                     dec_mask,
-                    student_outputs.decoder_hidden_states,
-                    teacher_outputs.decoder_hidden_states,
+                    student_outputs["decoder_hidden_states"],
+                    teacher_outputs["decoder_hidden_states"],
                     self.d_matches,
                     normalize_hidden=self.hparams.normalize_hidden,
                 )