run_pplm.py bug fix (#4867)

`is_leaf` may become `False` after `.to(device=device)` function call.
This commit is contained in:
songyouwei 2020-06-10 07:14:27 +08:00 committed by GitHub
parent 13aa174112
commit 29c36e9f36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 0 deletions

View File

@ -148,6 +148,9 @@ def perturb_past(
for i in range(num_iterations):
print("Iteration ", i + 1)
curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
# make sure p_.grad is not None
for p_ in curr_perturbation:
p_.retain_grad()
# Compute hidden using perturbed past
perturbed_past = list(map(add, past, curr_perturbation))