[s2s] distillation apex breaks return_dict obj (#8631)
* apex breaks return_dict obj * style
This commit is contained in:
parent
bf3611b2ab
commit
d86d57faa3
|
@ -154,7 +154,7 @@ class SummarizationDistiller(SummarizationModule):
|
|||
output_attentions=False,
|
||||
use_cache=False,
|
||||
)
|
||||
lm_logits = student_outputs.logits
|
||||
lm_logits = student_outputs["logits"]
|
||||
|
||||
# Same cross entropy vs. label smoothing logic as finetune.py
|
||||
assert lm_logits.shape[-1] == self.model.config.vocab_size
|
||||
|
@ -171,7 +171,9 @@ class SummarizationDistiller(SummarizationModule):
|
|||
def zero_tensor():
|
||||
return torch.tensor(0.0).type_as(student_lm_loss)
|
||||
|
||||
teacher_enc_outputs = student_outputs.encoder_last_hidden_state # use this unless self.different_base_models
|
||||
teacher_enc_outputs = student_outputs[
|
||||
"encoder_last_hidden_state"
|
||||
] # use this unless self.different_base_models
|
||||
hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
|
||||
if self.different_encoder: # compute encoder hidden state loss
|
||||
all_teacher_encoder_outputs = self.teacher.get_encoder()(
|
||||
|
@ -180,12 +182,12 @@ class SummarizationDistiller(SummarizationModule):
|
|||
output_hidden_states=self.do_calc_hidden_loss,
|
||||
)
|
||||
if self.different_base_models:
|
||||
teacher_enc_outputs = all_teacher_encoder_outputs.last_hidden_state
|
||||
teacher_enc_outputs = all_teacher_encoder_outputs["last_hidden_state"]
|
||||
elif self.do_calc_hidden_loss:
|
||||
hid_loss_enc = self.calc_hidden_loss(
|
||||
src_mask,
|
||||
student_outputs.encoder_hidden_states,
|
||||
all_teacher_encoder_outputs.hidden_states,
|
||||
student_outputs["encoder_hidden_states"],
|
||||
all_teacher_encoder_outputs["hidden_states"],
|
||||
self.e_matches,
|
||||
normalize_hidden=self.hparams.normalize_hidden,
|
||||
)
|
||||
|
@ -199,12 +201,12 @@ class SummarizationDistiller(SummarizationModule):
|
|||
use_cache=False, # since we are not passing labels, never let this default to True
|
||||
)
|
||||
dec_mask = decoder_input_ids.ne(pad_token_id)
|
||||
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs.logits)
|
||||
loss_ce = self.calc_ce_loss(dec_mask, lm_logits, teacher_outputs["logits"])
|
||||
if self.do_calc_hidden_loss: # Intermediate supervision of decoder hidden states
|
||||
hid_loss_dec = self.calc_hidden_loss(
|
||||
dec_mask,
|
||||
student_outputs.decoder_hidden_states,
|
||||
teacher_outputs.decoder_hidden_states,
|
||||
student_outputs["decoder_hidden_states"],
|
||||
teacher_outputs["decoder_hidden_states"],
|
||||
self.d_matches,
|
||||
normalize_hidden=self.hparams.normalize_hidden,
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue