fix the loss backward issue

(cherry picked from commit 566468cc984c6ec7e10dfc62b5b4191781a99cd2)
This commit is contained in:
Rosanne Liu 2019-11-03 04:51:57 +00:00 committed by Julien Chaumond
parent 572c24cfa2
commit 83b1e6ac9e
1 changed files with 6 additions and 5 deletions

View File

@ -36,6 +36,7 @@ from tqdm import trange
from transformers import GPT2Tokenizer
from transformers.file_utils import cached_path
from transformers.modeling_gpt2 import GPT2LMHeadModel
from IPython import embed
PPLM_BOW = 1
PPLM_DISCRIM = 2
@ -246,8 +247,8 @@ def perturb_past(
inputs_embeds=inputs_embeds
)
# get expected hidden states
unpert_hidden = curr_all_hidden[1]
accumulated_hidden += torch.sum(unpert_hidden, dim=1)
unpert_hidden = curr_all_hidden[-1]
accumulated_hidden += torch.sum(unpert_hidden, dim=1).detach()
prediction = classifier(
accumulated_hidden / (curr_length + 1 + horizon_length)
@ -257,7 +258,7 @@ def perturb_past(
discrim_loss += ce_loss(prediction, label)
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
if kl_scale > 0.0:
if kl_scale >= 0.0:
unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
unpert_probs = (
unpert_probs + SMALL_CONST *
@ -270,7 +271,7 @@ def perturb_past(
torch.FloatTensor
).cuda().detach()
corrected_probs = probs + correction.detach()
kl_loss += kl_scale * (
kl_loss = kl_scale * (
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
)
print(' kl_loss', (kl_loss).data.cpu().numpy())
@ -280,7 +281,7 @@ def perturb_past(
print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
# compute gradients
loss.backward(retain_graph=True)
loss.backward()
# calculate gradient norms
if grad_norms is not None and loss_type == PPLM_BOW: