Changed order of some parameters to be more consistent. Identical results.
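
A minimal illustrative sketch (not part of this commit; the functions below are made up) of why the reordering keeps results identical: parameters with default values can be reordered freely in a signature as long as every call site passes them by keyword, which is how the call sites in this diff pass them.

# Hypothetical example, not from run_pplm.py: reordering defaulted
# parameters does not change behavior when arguments are passed by name.
def scaled_sum_old(a, b, scale=1.0, offset=0.0):
    return (a + b) * scale + offset

def scaled_sum_new(a, b, offset=0.0, scale=1.0):  # same parameters, reordered
    return (a + b) * scale + offset

# Both calls bind arguments by keyword, so the order in the signature is irrelevant.
assert scaled_sum_old(1, 2, scale=2.0, offset=0.5) == scaled_sum_new(1, 2, scale=2.0, offset=0.5)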

w4nderlust 2019-11-29 20:19:33 -08:00 committed by Julien Chaumond
parent f42816e7fc
commit 893d0d64fe
1 changed file with 55 additions and 56 deletions


@@ -121,17 +121,17 @@ def perturb_past(
     accumulated_hidden=None,
     grad_norms=None,
     stepsize=0.01,
+    one_hot_bows_vectors=None,
     classifier=None,
     class_label=None,
-    one_hot_bows_vectors=None,
     loss_type=0,
     num_iterations=3,
-    kl_scale=0.01,
-    window_length=0,
     horizon_length=1,
+    window_length=0,
     decay=False,
     gamma=1.5,
-    device='cuda'
+    kl_scale=0.01,
+    device='cuda',
 ):
     # Generate inital perturbed past
     grad_accumulator = [
@@ -351,8 +351,7 @@ def get_classifier(
 
 
 def get_bag_of_words_indices(bag_of_words_ids_or_paths: List[str], tokenizer) -> \
-        List[
-            List[List[int]]]:
+        List[List[List[int]]]:
     bow_indices = []
     for id_or_path in bag_of_words_ids_or_paths:
         if id_or_path in BAG_OF_WORDS_ARCHIVE_MAP:
@@ -388,22 +387,22 @@ def full_text_generation(
     context=None,
     num_samples=1,
     device="cuda",
-    sample=False,
+    bag_of_words=None,
     discrim=None,
     class_label=None,
-    bag_of_words=None,
     length=100,
-    grad_length=10000,
     stepsize=0.02,
-    num_iterations=3,
     temperature=1.0,
-    gm_scale=0.9,
-    kl_scale=0.01,
     top_k=10,
-    window_length=0,
+    sample=False,
+    num_iterations=3,
+    grad_length=10000,
     horizon_length=1,
+    window_length=0,
     decay=False,
     gamma=1.5,
+    gm_scale=0.9,
+    kl_scale=0.01,
     **kwargs
 ):
     classifier, class_id = get_classifier(
@@ -454,24 +453,24 @@ def full_text_generation(
             tokenizer=tokenizer,
             context=context,
             device=device,
-            sample=sample,
             perturb=True,
             bow_indices=bow_indices,
             classifier=classifier,
             class_label=class_id,
             loss_type=loss_type,
             length=length,
-            grad_length=grad_length,
             stepsize=stepsize,
-            num_iterations=num_iterations,
             temperature=temperature,
-            gm_scale=gm_scale,
-            kl_scale=kl_scale,
             top_k=top_k,
-            window_length=window_length,
+            sample=sample,
+            num_iterations=num_iterations,
+            grad_length=grad_length,
             horizon_length=horizon_length,
+            window_length=window_length,
             decay=decay,
             gamma=gamma,
+            gm_scale=gm_scale,
+            kl_scale=kl_scale,
         )
         pert_gen_tok_texts.append(pert_gen_tok_text)
         if classifier is not None:
@@ -490,24 +489,24 @@ def generate_text_pplm(
     context=None,
     past=None,
     device="cuda",
-    sample=False,
     perturb=True,
+    bow_indices=None,
     classifier=None,
     class_label=None,
-    bow_indices=None,
     loss_type=0,
     length=100,
-    grad_length=10000,
     stepsize=0.02,
-    num_iterations=3,
     temperature=1.0,
-    gm_scale=0.9,
-    kl_scale=0.01,
     top_k=10,
-    window_length=0,
+    sample=False,
+    num_iterations=3,
+    grad_length=10000,
     horizon_length=1,
+    window_length=0,
     decay=False,
     gamma=1.5,
+    gm_scale=0.9,
+    kl_scale=0.01,
 ):
     output_so_far = (
         torch.tensor(context, device=device, dtype=torch.long).unsqueeze(0)
@@ -561,17 +560,17 @@ def generate_text_pplm(
                 accumulated_hidden=accumulated_hidden,
                 grad_norms=grad_norms,
                 stepsize=current_stepsize,
+                one_hot_bows_vectors=one_hot_bows_vectors,
                 classifier=classifier,
                 class_label=class_label,
-                one_hot_bows_vectors=one_hot_bows_vectors,
                 loss_type=loss_type,
                 num_iterations=num_iterations,
-                kl_scale=kl_scale,
-                window_length=window_length,
                 horizon_length=horizon_length,
+                window_length=window_length,
                 decay=decay,
                 gamma=gamma,
-                device=device
+                kl_scale=kl_scale,
+                device=device,
             )
             loss_in_time.append(loss_this_iter)
         else:
@@ -685,7 +684,7 @@ def run_pplm_example(
         pretrained_model = DISCRIMINATOR_MODELS_PARAMS[discrim][
             "pretrained_model"
         ]
-        print("discrim = {}, setting pretrained_model "
+        print("discrim = {}, pretrained_model set "
               "to discriminator's = {}".format(discrim, pretrained_model))
 
     # load pretrained model
@@ -810,6 +809,20 @@ if __name__ == '__main__':
         default="gpt2-medium",
         help="pretrained model name or path to local checkpoint",
     )
+    parser.add_argument(
+        "--cond_text", type=str, default="The lake",
+        help="Prefix texts to condition on"
+    )
+    parser.add_argument(
+        "--uncond", action="store_true",
+        help="Generate from end-of-text as prefix"
+    )
+    parser.add_argument(
+        "--num_samples",
+        type=int,
+        default=1,
+        help="Number of samples to generate from the modified latents",
+    )
     parser.add_argument(
         "--bag_of_words",
         "-B",
@@ -837,40 +850,16 @@ if __name__ == '__main__':
         default=-1,
         help="Class label used for the discriminator",
     )
-    parser.add_argument("--stepsize", type=float, default=0.02)
     parser.add_argument("--length", type=int, default=100)
-    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--stepsize", type=float, default=0.02)
     parser.add_argument("--temperature", type=float, default=1.0)
     parser.add_argument("--top_k", type=int, default=10)
-    parser.add_argument("--gm_scale", type=float, default=0.9)
-    parser.add_argument("--kl_scale", type=float, default=0.01)
-    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
     parser.add_argument(
         "--sample", action="store_true",
         help="Generate from end-of-text as prefix"
     )
-    parser.add_argument(
-        "--uncond", action="store_true",
-        help="Generate from end-of-text as prefix"
-    )
-    parser.add_argument(
-        "--cond_text", type=str, default="The lake",
-        help="Prefix texts to condition on"
-    )
     parser.add_argument("--num_iterations", type=int, default=3)
     parser.add_argument("--grad_length", type=int, default=10000)
-    parser.add_argument(
-        "--num_samples",
-        type=int,
-        default=1,
-        help="Number of samples to generate from the modified latents",
-    )
-    parser.add_argument(
-        "--horizon_length",
-        type=int,
-        default=1,
-        help="Length of future to optimize over",
-    )
     parser.add_argument(
         "--window_length",
         type=int,
@@ -878,9 +867,19 @@ if __name__ == '__main__':
         help="Length of past which is being optimized; "
              "0 corresponds to infinite window length",
     )
+    parser.add_argument(
+        "--horizon_length",
+        type=int,
+        default=1,
+        help="Length of future to optimize over",
+    )
     parser.add_argument("--decay", action="store_true",
                         help="whether to decay or not")
     parser.add_argument("--gamma", type=float, default=1.5)
+    parser.add_argument("--gm_scale", type=float, default=0.9)
+    parser.add_argument("--kl_scale", type=float, default=0.01)
+    parser.add_argument("--seed", type=int, default=0)
+    parser.add_argument("--no_cuda", action="store_true", help="no cuda")
     parser.add_argument("--colorama", action="store_true",
                         help="colors keywords")