Fix test failures due to old torch issue with non-contiguous view

This commit is contained in:
Catalin Voss 2019-03-24 14:37:13 -07:00
parent 0dd796e359
commit fda2f62395
2 changed files with 8 additions and 8 deletions

View File

@ -618,8 +618,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
lm_logits = self.lm_head(hidden_states)
if lm_labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[:, :-1]
shift_labels = lm_labels[:, 1:]
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()
# In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
# in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
@ -698,8 +698,8 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
if lm_labels is not None:
shift_logits = lm_logits[:, :-1]
shift_labels = lm_labels[:, 1:]
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(shift_logits.view(-1,
shift_logits.size(-1)), shift_labels.view(-1)))

View File

@ -717,8 +717,8 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
lm_logits = self.lm_head(hidden_states)
if lm_labels is not None:
# Shift so that tokens < n predict n
shift_logits = lm_logits[:, :-1]
shift_labels = lm_labels[:, 1:]
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()
# In tensorflow, it's [batch, d_0, d_1, ..., d_{r-1}, num_classes]
# in pytorch, it's [batch, num_classes, d_0, d_1, ..., d_{r-1}]
@ -811,8 +811,8 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
losses = []
if lm_labels is not None:
shift_logits = lm_logits[:, :-1]
shift_labels = lm_labels[:, 1:]
shift_logits = lm_logits[:, :-1].contiguous()
shift_labels = lm_labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
losses.append(loss_fct(shift_logits.view(-1,
shift_logits.size(-1)), shift_labels.view(-1)))