Fix typo syntax err (sorry, c/p from my repo)

This commit is contained in:
Catalin Voss 2019-03-24 13:49:42 -07:00
parent 2e6f5ffb96
commit 472857c47f
2 changed files with 2 additions and 2 deletions

View File

@ -625,7 +625,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
# in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
# We just flatten the tokens out this way.
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return loss
return lm_logits, presents

View File

@ -724,7 +724,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
# in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
# We just flatten the tokens out this way.
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1))
return loss
return lm_logits