Changed order of some parameters to be more consistent. Identical results.
This commit is contained in:
parent
f42816e7fc
commit
893d0d64fe
|
@ -121,17 +121,17 @@ def perturb_past(
|
|||
accumulated_hidden=None,
|
||||
grad_norms=None,
|
||||
stepsize=0.01,
|
||||
one_hot_bows_vectors=None,
|
||||
classifier=None,
|
||||
class_label=None,
|
||||
one_hot_bows_vectors=None,
|
||||
loss_type=0,
|
||||
num_iterations=3,
|
||||
kl_scale=0.01,
|
||||
window_length=0,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
device='cuda'
|
||||
kl_scale=0.01,
|
||||
device='cuda',
|
||||
):
|
||||
# Generate inital perturbed past
|
||||
grad_accumulator = [
|
||||
|
@ -351,8 +351,7 @@ def get_classifier(
|
|||
|
||||
|
||||
def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
|
||||
List[
|
||||
List[List[int]]]:
|
||||
List[List[List[int]]]:
|
||||
bow_indices = []
|
||||
for id_or_path in bag_of_words_ids_or_paths:
|
||||
if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
|
||||
|
@ -388,22 +387,22 @@ def full_text_generation(
|
|||
context=None,
|
||||
num_samples=1,
|
||||
device="cuda",
|
||||
sample=False,
|
||||
bag_of_words=None,
|
||||
discrim=None,
|
||||
class_label=None,
|
||||
bag_of_words=None,
|
||||
length=100,
|
||||
grad_length=10000,
|
||||
stepsize=0.02,
|
||||
num_iterations=3,
|
||||
temperature=1.0,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
top_k=10,
|
||||
window_length=0,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
**kwargs
|
||||
):
|
||||
classifier, class_id = get_classifier(
|
||||
|
@ -454,24 +453,24 @@ def full_text_generation(
|
|||
tokenizer=tokenizer,
|
||||
context=context,
|
||||
device=device,
|
||||
sample=sample,
|
||||
perturb=True,
|
||||
bow_indices=bow_indices,
|
||||
classifier=classifier,
|
||||
class_label=class_id,
|
||||
loss_type=loss_type,
|
||||
length=length,
|
||||
grad_length=grad_length,
|
||||
stepsize=stepsize,
|
||||
num_iterations=num_iterations,
|
||||
temperature=temperature,
|
||||
gm_scale=gm_scale,
|
||||
kl_scale=kl_scale,
|
||||
top_k=top_k,
|
||||
window_length=window_length,
|
||||
sample=sample,
|
||||
num_iterations=num_iterations,
|
||||
grad_length=grad_length,
|
||||
horizon_length=horizon_length,
|
||||
window_length=window_length,
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
gm_scale=gm_scale,
|
||||
kl_scale=kl_scale,
|
||||
)
|
||||
pert_gen_tok_texts.append(pert_gen_tok_text)
|
||||
if classifier is not None:
|
||||
|
@ -490,24 +489,24 @@ def generate_text_pplm(
|
|||
context=None,
|
||||
past=None,
|
||||
device="cuda",
|
||||
sample=False,
|
||||
perturb=True,
|
||||
bow_indices=None,
|
||||
classifier=None,
|
||||
class_label=None,
|
||||
bow_indices=None,
|
||||
loss_type=0,
|
||||
length=100,
|
||||
grad_length=10000,
|
||||
stepsize=0.02,
|
||||
num_iterations=3,
|
||||
temperature=1.0,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
top_k=10,
|
||||
window_length=0,
|
||||
sample=False,
|
||||
num_iterations=3,
|
||||
grad_length=10000,
|
||||
horizon_length=1,
|
||||
window_length=0,
|
||||
decay=False,
|
||||
gamma=1.5,
|
||||
gm_scale=0.9,
|
||||
kl_scale=0.01,
|
||||
):
|
||||
output_so_far = (
|
||||
torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
|
||||
|
@ -561,17 +560,17 @@ def generate_text_pplm(
|
|||
accumulated_hidden=accumulated_hidden,
|
||||
grad_norms=grad_norms,
|
||||
stepsize=current_stepsize,
|
||||
one_hot_bows_vectors=one_hot_bows_vectors,
|
||||
classifier=classifier,
|
||||
class_label=class_label,
|
||||
one_hot_bows_vectors=one_hot_bows_vectors,
|
||||
loss_type=loss_type,
|
||||
num_iterations=num_iterations,
|
||||
kl_scale=kl_scale,
|
||||
window_length=window_length,
|
||||
horizon_length=horizon_length,
|
||||
window_length=window_length,
|
||||
decay=decay,
|
||||
gamma=gamma,
|
||||
device=device
|
||||
kl_scale=kl_scale,
|
||||
device=device,
|
||||
)
|
||||
loss_in_time.append(loss_this_iter)
|
||||
else:
|
||||
|
@ -685,7 +684,7 @@ def run_pplm_example(
|
|||
pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
|
||||
"pretrained_model"
|
||||
]
|
||||
print("discrim = {}, setting pretrained_model "
|
||||
print("discrim = {}, pretrained_model set "
|
||||
"to discriminator's = {}".format(discrim, pretrained_model))
|
||||
|
||||
# load pretrained model
|
||||
|
@ -810,6 +809,20 @@ if __name__ == '__main__':
|
|||
default="gpt2-medium",
|
||||
help="pretrained model name or path to local checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cond_text", type=str, default="The lake",
|
||||
help="Prefix texts to condition on"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--uncond", action="store_true",
|
||||
help="Generate from end-of-text as prefix"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_samples",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of samples to generate from the modified latents",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bag_of_words",
|
||||
"-B",
|
||||
|
@ -837,40 +850,16 @@ if __name__ == '__main__':
|
|||
default=-1,
|
||||
help="Class label used for the discriminator",
|
||||
)
|
||||
parser.add_argument("--stepsize", type=float, default=0.02)
|
||||
parser.add_argument("--length", type=int, default=100)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--stepsize", type=float, default=0.02)
|
||||
parser.add_argument("--temperature", type=float, default=1.0)
|
||||
parser.add_argument("--top_k", type=int, default=10)
|
||||
parser.add_argument("--gm_scale", type=float, default=0.9)
|
||||
parser.add_argument("--kl_scale", type=float, default=0.01)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
||||
parser.add_argument(
|
||||
"--sample", action="store_true",
|
||||
help="Generate from end-of-text as prefix"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--uncond", action="store_true",
|
||||
help="Generate from end-of-text as prefix"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cond_text", type=str, default="The lake",
|
||||
help="Prefix texts to condition on"
|
||||
)
|
||||
parser.add_argument("--num_iterations", type=int, default=3)
|
||||
parser.add_argument("--grad_length", type=int, default=10000)
|
||||
parser.add_argument(
|
||||
"--num_samples",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of samples to generate from the modified latents",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--horizon_length",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Length of future to optimize over",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--window_length",
|
||||
type=int,
|
||||
|
@ -878,9 +867,19 @@ if __name__ == '__main__':
|
|||
help="Length of past which is being optimized; "
|
||||
"0 corresponds to infinite window length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--horizon_length",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Length of future to optimize over",
|
||||
)
|
||||
parser.add_argument("--decay", action="store_true",
|
||||
help="whether to decay or not")
|
||||
parser.add_argument("--gamma", type=float, default=1.5)
|
||||
parser.add_argument("--gm_scale", type=float, default=0.9)
|
||||
parser.add_argument("--kl_scale", type=float, default=0.01)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="no cuda")
|
||||
parser.add_argument("--colorama", action="store_true",
|
||||
help="colors keywords")
|
||||
|
||||
|
|
Loading…
Reference in New Issue