fix the loss backward issue
(cherry picked from commit 566468cc984c6ec7e10dfc62b5b4191781a99cd2)
parent 572c24cfa2
commit 83b1e6ac9e
@@ -36,6 +36,7 @@ from tqdm import trange
 from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
 from transformers.modeling_gpt2 import GPT2LMHeadModel
+from IPython import embed
 
 PPLM_BOW = 1
 PPLM_DISCRIM = 2
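
Note: the one added line in this hunk, from IPython import embed, is an interactive-debugging aid: calling embed() opens an IPython shell at that point in the program so local tensors can be inspected. A minimal sketch of the typical use (the function here is illustrative, not from this file):

    from IPython import embed

    def debug_point(hidden):
        # embed()  # uncomment to pause here in an interactive IPython shell
        return hidden

Such imports are usually stripped before merging.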
@@ -246,8 +247,8 @@ def perturb_past(
             inputs_embeds=inputs_embeds
         )
         # get expected hidden states
-        unpert_hidden = curr_all_hidden[1]
-        accumulated_hidden += torch.sum(unpert_hidden, dim=1)
+        unpert_hidden = curr_all_hidden[-1]
+        accumulated_hidden += torch.sum(unpert_hidden, dim=1).detach()
 
         prediction = classifier(
             accumulated_hidden / (curr_length + 1 + horizon_length)
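
Note: two changes land here. curr_all_hidden holds the per-layer hidden states (embedding output first, final layer last), so the index moves from 1, the first block's output, to -1, the final layer's. The new .detach() keeps this horizon-rollout term out of the autograd graph, so repeated backward() calls never re-traverse it; gradients can still reach the perturbation through whatever non-detached terms accumulated_hidden already carries. A minimal sketch of the pattern with made-up shapes (batch 1, 5 tokens, hidden size 768):

    import torch

    # stand-in for the model's hidden-state tuple: embeddings + 12 layers
    curr_all_hidden = tuple(torch.randn(1, 5, 768, requires_grad=True)
                            for _ in range(13))

    accumulated_hidden = torch.zeros(1, 768)
    unpert_hidden = curr_all_hidden[-1]  # final layer, not the tuple's second entry
    accumulated_hidden = accumulated_hidden + torch.sum(unpert_hidden, dim=1).detach()

    assert accumulated_hidden.grad_fn is None  # the rollout term carries no graph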
@@ -257,7 +258,7 @@ def perturb_past(
             discrim_loss += ce_loss(prediction, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
 
-        if kl_scale > 0.0:
+        if kl_scale >= 0.0:
             unpert_probs = F.softmax(unpert_logits[:, -1, :], dim=-1)
             unpert_probs = (
                 unpert_probs + SMALL_CONST *
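
Note: the guard relaxes from kl_scale > 0.0 to >= 0.0, so the KL term (and its debug print) is now computed even when its weight is zero. A minimal CPU sketch of the smoothed KL that follows, where SMALL_CONST guards against log(0); the masking expression is reconstructed from the visible lines and may differ slightly from the file:

    import torch
    import torch.nn.functional as F

    SMALL_CONST = 1e-10
    kl_scale = 0.01
    logits = torch.randn(1, 50257, requires_grad=True)  # perturbed logits (made up)
    unpert_logits = torch.randn(1, 50257)               # unperturbed logits (made up)

    probs = F.softmax(logits, dim=-1)
    unpert_probs = F.softmax(unpert_logits, dim=-1)
    unpert_probs = unpert_probs + SMALL_CONST * (
        unpert_probs <= SMALL_CONST).type(torch.FloatTensor).detach()
    correction = SMALL_CONST * (probs <= SMALL_CONST).type(torch.FloatTensor).detach()
    corrected_probs = probs + correction
    kl_loss = kl_scale * (
        (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
    )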
@@ -270,7 +271,7 @@ def perturb_past(
                     torch.FloatTensor
                 ).cuda().detach()
             corrected_probs = probs + correction.detach()
-            kl_loss += kl_scale * (
+            kl_loss = kl_scale * (
                 (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
             )
             print(' kl_loss', (kl_loss).data.cpu().numpy())
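
Note: this hunk is the core of the "loss backward issue". With kl_loss +=, the loss built at iteration i still references iteration i-1's graph, whose buffers the previous backward() already freed, so PyTorch raises "Trying to backward through the graph a second time" unless retain_graph=True. Rebinding with = gives every iteration a self-contained graph. A self-contained toy reproduction (not the repo's code):

    import torch

    w = torch.ones(1, requires_grad=True)

    # buggy: the running sum chains each previous graph into the new loss
    kl_loss = torch.zeros(1)
    try:
        for _ in range(2):
            kl_loss = kl_loss + (w * w).sum()
            kl_loss.backward()  # second pass walks into the freed graph
    except RuntimeError as err:
        print("fails as expected:", err)

    # fixed: rebuild the loss each iteration; plain backward() suffices
    for _ in range(2):
        kl_loss = (w * w).sum()
        kl_loss.backward()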
@@ -280,7 +281,7 @@ def perturb_past(
         print(' pplm_loss', (loss - kl_loss).data.cpu().numpy())
 
         # compute gradients
-        loss.backward(retain_graph=True)
+        loss.backward()
 
         # calculate gradient norms
         if grad_norms is not None and loss_type == PPLM_BOW:
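
Note: with the KL accumulation above no longer chaining graphs across iterations, retain_graph=True becomes unnecessary; dropping it also lets PyTorch free each iteration's activation buffers right after backward() instead of holding them for the whole perturbation loop. A minimal sketch of the resulting per-iteration pattern, with placeholder names:

    import torch

    past = torch.randn(2, 4, requires_grad=True)  # stand-in for the perturbed past
    for _ in range(3):
        loss = (past ** 2).sum()  # fresh graph every iteration
        loss.backward()           # plain backward(), no retain_graph needed
        # ... gradient-norm bookkeeping and the update on `past` go here ...
        past.grad = None          # clear before the next pass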