run_pplm.py bug fix (#4867)
`is_leaf` may become `False` after `.to(device=device)` function call.
This commit is contained in:
parent
13aa174112
commit
29c36e9f36
|
@ -148,6 +148,9 @@ def perturb_past(
|
|||
for i in range(num_iterations):
|
||||
print("Iteration ", i + 1)
|
||||
curr_perturbation = [torch.from_numpy(p_).requires_grad_(True).to(device=device) for p_ in grad_accumulator]
|
||||
# make sure p_.grad is not None
|
||||
for p_ in curr_perturbation:
|
||||
p_.retain_grad()
|
||||
|
||||
# Compute hidden using perturbed past
|
||||
perturbed_past = list(map(add, past, curr_perturbation))
|
||||
|
|
Loading…
Reference in New Issue