Added support for generic discriminators

piero 2019-11-27 21:45:19 -08:00 committed by Julien Chaumond
parent b0eaff36e6
commit 7fd54b55a3
1 changed file with 51 additions and 26 deletions


@@ -14,17 +14,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: add code for training a custom discriminator
"""
Example command with bag of words:
python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
Example command with discriminator:
-python examples/run_pplm.py -D sentiment --label_class 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
+python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
"""
import argparse
+import json
from operator import add
from typing import List, Optional, Tuple, Union
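With the generic option introduced in this commit, a user-trained discriminator can be supplied from the command line. A hypothetical invocation (the weights and meta file names are placeholders, not assets shipped with this change; the remaining flags follow the sentiment example above):

    python examples/run_pplm.py -D generic --discrim_weights my_discrim.pt --discrim_meta my_discrim_meta.json --cond_text "The lake" --length 10 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95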
@@ -121,7 +120,7 @@ def perturb_past(
grad_norms=None,
stepsize=0.01,
classifier=None,
-label_class=None,
+class_label=None,
one_hot_bows_vectors=None,
loss_type=0,
num_iterations=3,
@@ -230,7 +229,7 @@ def perturb_past(
prediction = classifier(new_accumulated_hidden /
(curr_length + 1 + horizon_length))
-label = torch.tensor([label_class], device=device,
+label = torch.tensor([class_label], device=device,
dtype=torch.long)
discrim_loss = ce_loss(prediction, label)
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
@@ -244,7 +243,8 @@ def perturb_past(
unpert_probs + SMALL_CONST *
(unpert_probs <= SMALL_CONST).float().to(device).detach()
)
-correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
+correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
+    device).detach()
corrected_probs = probs + correction.detach()
kl_loss = kl_scale * (
(corrected_probs * (corrected_probs / unpert_probs).log()).sum()
@@ -273,7 +273,8 @@ def perturb_past(
# normalize gradients
grad = [
-stepsize *
-(p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
+(p_.grad * window_mask / grad_norms[
+    index] ** gamma).data.cpu().numpy()
for index, p_ in enumerate(curr_perturbation)
]
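The expression being rewrapped here is PPLM's normalized update: each perturbation tensor moves against its gradient, masked to the decoding window and rescaled by the windowed gradient norm raised to gamma. A toy illustration with stand-in tensors (values are illustrative only; only the formula mirrors the code above):

    import torch

    stepsize, gamma = 0.01, 1.5
    p_grad = torch.randn(1, 10)                   # stand-in for p_.grad
    window_mask = torch.ones_like(p_grad)         # restricts the update window
    grad_norm = torch.norm(p_grad * window_mask)  # one entry of grad_norms
    grad = -stepsize * (p_grad * window_mask / grad_norm ** gamma)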
@@ -301,7 +302,7 @@ def perturb_past(
def get_classifier(
-name: Optional[str], label_class: Union[str, int],
+name: Optional[str], class_label: Union[str, int],
device: str
) -> Tuple[Optional[ClassificationHead], Optional[int]]:
if name is None:
@@ -312,26 +313,29 @@ def get_classifier(
class_size=params['class_size'],
embed_size=params['embed_size']
).to(device)
-resolved_archive_file = cached_path(params["url"])
+if "url" in params:
+    resolved_archive_file = cached_path(params["url"])
+else:
+    resolved_archive_file = params["path"]
classifier.load_state_dict(
torch.load(resolved_archive_file, map_location=device))
classifier.eval()
-if isinstance(label_class, str):
-if label_class in params["class_vocab"]:
-label_id = params["class_vocab"][label_class]
+if isinstance(class_label, str):
+if class_label in params["class_vocab"]:
+label_id = params["class_vocab"][class_label]
else:
label_id = params["default_class"]
print("label_class {} not in class_vocab".format(label_class))
print("class_label {} not in class_vocab".format(class_label))
print("available values are: {}".format(params["class_vocab"]))
print("using default class {}".format(label_id))
-elif isinstance(label_class, int):
-if label_class in set(params["class_vocab"].values()):
-label_id = label_class
+elif isinstance(class_label, int):
+if class_label in set(params["class_vocab"].values()):
+label_id = class_label
else:
label_id = params["default_class"]
print("label_class {} not in class_vocab".format(label_class))
print("class_label {} not in class_vocab".format(class_label))
print("available values are: {}".format(params["class_vocab"]))
print("using default class {}".format(label_id))
@@ -379,7 +383,7 @@ def full_text_generation(
device="cuda",
sample=True,
discrim=None,
-label_class=None,
+class_label=None,
bag_of_words=None,
length=100,
grad_length=10000,
@@ -397,7 +401,7 @@ def full_text_generation(
):
classifier, class_id = get_classifier(
discrim,
-label_class,
+class_label,
device
)
@@ -443,7 +447,7 @@ def full_text_generation(
perturb=True,
bow_indices=bow_indices,
classifier=classifier,
-label_class=class_id,
+class_label=class_id,
loss_type=loss_type,
length=length,
grad_length=grad_length,
@@ -477,7 +481,7 @@ def generate_text_pplm(
sample=True,
perturb=True,
classifier=None,
-label_class=None,
+class_label=None,
bow_indices=None,
loss_type=0,
length=100,
@@ -545,7 +549,7 @@ def generate_text_pplm(
grad_norms=grad_norms,
stepsize=current_stepsize,
classifier=classifier,
-label_class=label_class,
+class_label=class_label,
one_hot_bows_vectors=one_hot_bows_vectors,
loss_type=loss_type,
num_iterations=num_iterations,
@@ -567,7 +571,7 @@ def generate_text_pplm(
if classifier is not None:
ce_loss = torch.nn.CrossEntropyLoss()
prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
-label = torch.tensor([label_class], device=device,
+label = torch.tensor([class_label], device=device,
dtype=torch.long)
unpert_discrim_loss = ce_loss(prediction, label)
print(
@@ -613,6 +617,20 @@ def generate_text_pplm(
return output_so_far, unpert_discrim_loss, loss_in_time
+def set_generic_model_params(discrim_weights, discrim_meta):
+    if discrim_weights is None:
+        raise ValueError('When using a generic discriminator, '
+                         'discrim_weights needs to be specified')
+    if discrim_meta is None:
+        raise ValueError('When using a generic discriminator, '
+                         'discrim_meta needs to be specified')
+    with open(discrim_meta, 'r') as discrim_meta_file:
+        meta = json.load(discrim_meta_file)
+    meta['path'] = discrim_weights
+    DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
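The helper mutates the module-level DISCRIMINATOR_MODELS_PARAMS registry, which is why the existing get_classifier code can now resolve user-supplied weights through the new "path" branch instead of downloading from a "url". A minimal sketch of the resulting flow (file names and label are hypothetical):

    set_generic_model_params('my_discrim.pt', 'my_discrim_meta.json')
    classifier, class_id = get_classifier('generic', 'positive', 'cpu')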
def run_model():
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -636,11 +654,15 @@ def run_model():
"-D",
type=str,
default=None,
choices=("clickbait", "sentiment", "toxicity"),
help="Discriminator to use for loss-type 2",
choices=("clickbait", "sentiment", "toxicity", "generic"),
help="Discriminator to use",
)
+parser.add_argument('--discrim_weights', type=str, default=None,
+                    help='Weights for the generic discriminator')
+parser.add_argument('--discrim_meta', type=str, default=None,
+                    help='Meta information for the generic discriminator')
parser.add_argument(
"--label_class",
"--class_label",
type=int,
default=-1,
help="Class label used for the discriminator",
@@ -697,6 +719,9 @@ def run_model():
# set the device
device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
+if args.discrim == 'generic':
+    set_generic_model_params(args.discrim_weights, args.discrim_meta)
# load pretrained model
model = GPT2LMHeadModel.from_pretrained(
args.model_path,