Fix test failures due to an old torch issue with .view() on non-contiguous tensors
parent 0dd796e359
commit fda2f62395
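For context on the fix: `Tensor.view` only works when the requested shape is compatible with the tensor's strides, and slicing away the last time step of `lm_logits` produces a non-contiguous view whenever the batch dimension is larger than one. Old torch releases raise a RuntimeError on `.view()` in that case (current ones still do; `.reshape()` is the lenient alternative), and `.contiguous()` fixes it by copying into a dense layout first. A minimal sketch reproducing the failure, with made-up tensor sizes:

import torch

lm_logits = torch.randn(2, 5, 10)        # [batch, seq_len, vocab], invented sizes
shift_logits = lm_logits[:, :-1]         # drops the last step; a view, not a copy
print(shift_logits.is_contiguous())      # False once batch > 1

try:
    shift_logits.view(-1, shift_logits.size(-1))
except RuntimeError as e:
    print("view failed:", e)             # size/stride incompatible

# .contiguous() materializes a dense copy, after which .view() succeeds:
flat = shift_logits.contiguous().view(-1, shift_logits.size(-1))
print(flat.shape)                        # torch.Size([8, 10])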
@@ -618,8 +618,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
-            shift_logits = lm_logits[:, :-1]
-            shift_labels = lm_labels[:, 1:]
+            shift_logits = lm_logits[:, :-1].contiguous()
+            shift_labels = lm_labels[:, 1:].contiguous()
 
             # In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
             # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
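The comments in this hunk refer to the class-dimension convention: TensorFlow losses expect `num_classes` last, while `torch.nn.CrossEntropyLoss` expects it directly after the batch dimension for multi-dimensional input. Flattening to `[N, num_classes]`, as the loss code in the hunks below does, sidesteps the transpose; a short sketch with invented sizes:

import torch
from torch.nn import CrossEntropyLoss

logits = torch.randn(2, 4, 10)           # [batch, seq, num_classes] (TF-style layout)
labels = torch.randint(0, 10, (2, 4))
loss_fct = CrossEntropyLoss()

# PyTorch wants [batch, num_classes, seq] for multi-dimensional input ...
loss_a = loss_fct(logits.transpose(1, 2), labels)

# ... but flattening to [batch * seq, num_classes] is equivalent:
loss_b = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
assert torch.allclose(loss_a, loss_b)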
@@ -698,8 +698,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
         if lm_labels is not None:
-            shift_logits = lm_logits[:, :-1]
-            shift_labels = lm_labels[:, 1:]
+            shift_logits = lm_logits[:, :-1].contiguous()
+            shift_labels = lm_labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             losses.append(loss_fct(shift_logits.view(-1,
                           shift_logits.size(-1)), shift_labels.view(-1)))
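One detail worth noting in this hunk: `CrossEntropyLoss(ignore_index=-1)` makes any label position set to -1 contribute nothing to the loss, which is how callers mask out padding when building `lm_labels`. A quick sketch with invented values:

import torch
from torch.nn import CrossEntropyLoss

loss_fct = CrossEntropyLoss(ignore_index=-1)
logits = torch.randn(4, 10)
labels = torch.tensor([3, -1, 7, -1])    # -1 marks positions to skip

# Only positions 0 and 2 enter the mean:
assert torch.allclose(loss_fct(logits, labels),
                      loss_fct(logits[[0, 2]], labels[[0, 2]]))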
@@ -717,8 +717,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
-            shift_logits = lm_logits[:, :-1]
-            shift_labels = lm_labels[:, 1:]
+            shift_logits = lm_logits[:, :-1].contiguous()
+            shift_labels = lm_labels[:, 1:].contiguous()
 
             # In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
             # in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
@@ -811,8 +811,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
         if lm_labels is not None:
-            shift_logits = lm_logits[:, :-1]
-            shift_labels = lm_labels[:, 1:]
+            shift_logits = lm_logits[:, :-1].contiguous()
+            shift_labels = lm_labels[:, 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             losses.append(loss_fct(shift_logits.view(-1,
                           shift_logits.size(-1)), shift_labels.view(-1)))
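Putting the pieces together, a check in the spirit of the tests this commit fixes (a hypothetical test with made-up shapes, not the repository's actual one) confirms that the patched shift-then-flatten path runs cleanly for batch sizes above one:

import torch
from torch.nn import CrossEntropyLoss

def lm_loss(lm_logits, lm_labels):
    # The pattern after this commit: shift, make contiguous, flatten.
    shift_logits = lm_logits[:, :-1].contiguous()
    shift_labels = lm_labels[:, 1:].contiguous()
    loss_fct = CrossEntropyLoss(ignore_index=-1)
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1))

lm_logits = torch.randn(3, 6, 11)        # [batch, seq_len, vocab]
lm_labels = torch.randint(0, 11, (3, 6))
assert lm_loss(lm_logits, lm_labels).dim() == 0   # scalar loss, no RuntimeError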