From 7fd54b55a3f7c3134f8cc5a62f4cc447a5cd34de Mon Sep 17 00:00:00 2001
From: piero
Date: Wed, 27 Nov 2019 21:45:19 -0800
Subject: [PATCH] Added support for generic discriminators

---
 examples/run_pplm.py | 77 +++++++++++++++++++++++++++++---------------
 1 file changed, 51 insertions(+), 26 deletions(-)

diff --git a/examples/run_pplm.py b/examples/run_pplm.py
index 0d6b0d635d..28aa66cc7d 100644
--- a/examples/run_pplm.py
+++ b/examples/run_pplm.py
@@ -14,17 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# TODO: add code for training a custom discriminator
-
 """
 Example command with bag of words:
 python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
 
 Example command with discriminator:
-python examples/run_pplm.py -D sentiment --label_class 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
+python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
 """
 
 import argparse
+import json
 from operator import add
 from typing import List, Optional, Tuple, Union
 
@@ -121,7 +120,7 @@ def perturb_past(
         grad_norms=None,
         stepsize=0.01,
         classifier=None,
-        label_class=None,
+        class_label=None,
         one_hot_bows_vectors=None,
         loss_type=0,
         num_iterations=3,
@@ -230,7 +229,7 @@
             prediction = classifier(new_accumulated_hidden /
                                     (curr_length + 1 + horizon_length))
-            label = torch.tensor([label_class], device=device,
+            label = torch.tensor([class_label], device=device,
                                  dtype=torch.long)
             discrim_loss = ce_loss(prediction, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
@@ -244,7 +243,8 @@
             unpert_probs + SMALL_CONST *
             (unpert_probs <= SMALL_CONST).float().to(device).detach()
         )
-        correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
+        correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
+            device).detach()
         corrected_probs = probs + correction.detach()
         kl_loss = kl_scale * (
             (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
@@ -273,7 +273,8 @@
         # normalize gradients
         grad = [
             -stepsize *
-            (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
+            (p_.grad * window_mask / grad_norms[
+                index] ** gamma).data.cpu().numpy()
             for index, p_ in enumerate(curr_perturbation)
         ]
 
@@ -301,7 +302,7 @@
 
 
 def get_classifier(
-        name: Optional[str], label_class: Union[str, int],
+        name: Optional[str], class_label: Union[str, int],
         device: str
 ) -> Tuple[Optional[ClassificationHead], Optional[int]]:
     if name is None:
@@ -312,26 +313,29 @@
         class_size=params['class_size'],
         embed_size=params['embed_size']
     ).to(device)
-    resolved_archive_file = cached_path(params["url"])
+    if "url" in params:
+        resolved_archive_file = cached_path(params["url"])
+    else:
+        resolved_archive_file = params["path"]
     classifier.load_state_dict(
         torch.load(resolved_archive_file, map_location=device))
     classifier.eval()
-    if isinstance(label_class, str):
-        if label_class in params["class_vocab"]:
-            label_id = params["class_vocab"][label_class]
+    if isinstance(class_label, str):
+        if class_label in params["class_vocab"]:
+            label_id = params["class_vocab"][class_label]
         else:
             label_id = params["default_class"]
-            print("label_class {} not in class_vocab".format(label_class))
+            print("class_label {} not in class_vocab".format(class_label))
             print("available values are: {}".format(params["class_vocab"]))
             print("using default class {}".format(label_id))
 
-    elif isinstance(label_class, int):
-        if label_class in set(params["class_vocab"].values()):
-            label_id = label_class
+    elif isinstance(class_label, int):
+        if class_label in set(params["class_vocab"].values()):
+            label_id = class_label
         else:
             label_id = params["default_class"]
-            print("label_class {} not in class_vocab".format(label_class))
+            print("class_label {} not in class_vocab".format(class_label))
             print("available values are: {}".format(params["class_vocab"]))
             print("using default class {}".format(label_id))
 
@@ -379,7 +383,7 @@
         device="cuda",
         sample=True,
         discrim=None,
-        label_class=None,
+        class_label=None,
        bag_of_words=None,
         length=100,
         grad_length=10000,
@@ -397,7 +401,7 @@
 ):
     classifier, class_id = get_classifier(
         discrim,
-        label_class,
+        class_label,
         device
     )
 
@@ -443,7 +447,7 @@
         perturb=True,
         bow_indices=bow_indices,
         classifier=classifier,
-        label_class=class_id,
+        class_label=class_id,
         loss_type=loss_type,
         length=length,
         grad_length=grad_length,
@@ -477,7 +481,7 @@
         sample=True,
         perturb=True,
         classifier=None,
-        label_class=None,
+        class_label=None,
         bow_indices=None,
         loss_type=0,
         length=100,
@@ -545,7 +549,7 @@
                 grad_norms=grad_norms,
                 stepsize=current_stepsize,
                 classifier=classifier,
-                label_class=label_class,
+                class_label=class_label,
                 one_hot_bows_vectors=one_hot_bows_vectors,
                 loss_type=loss_type,
                 num_iterations=num_iterations,
@@ -567,7 +571,7 @@
         if classifier is not None:
             ce_loss = torch.nn.CrossEntropyLoss()
             prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
-            label = torch.tensor([label_class], device=device,
+            label = torch.tensor([class_label], device=device,
                                  dtype=torch.long)
             unpert_discrim_loss = ce_loss(prediction, label)
             print(
@@ -613,6 +617,20 @@
     return output_so_far, unpert_discrim_loss, loss_in_time
 
 
+def set_generic_model_params(discrim_weights, discrim_meta):
+    if discrim_weights is None:
+        raise ValueError('When using a generic discriminator, '
+                         'discrim_weights need to be specified')
+    if discrim_meta is None:
+        raise ValueError('When using a generic discriminator, '
+                         'discrim_meta need to be specified')
+
+    with open(discrim_meta, 'r') as discrim_meta_file:
+        meta = json.load(discrim_meta_file)
+    meta['path'] = discrim_weights
+    DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
+
+
 def run_model():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -636,11 +654,15 @@
         "-D",
         type=str,
         default=None,
-        choices=("clickbait", "sentiment", "toxicity"),
-        help="Discriminator to use for loss-type 2",
+        choices=("clickbait", "sentiment", "toxicity", "generic"),
+        help="Discriminator to use",
     )
+    parser.add_argument('--discrim_weights', type=str, default=None,
+                        help='Weights for the generic discriminator')
+    parser.add_argument('--discrim_meta', type=str, default=None,
+                        help='Meta information for the generic discriminator')
     parser.add_argument(
-        "--label_class",
+        "--class_label",
         type=int,
         default=-1,
         help="Class label used for the discriminator",
@@ -697,6 +719,9 @@
     # set the device
     device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
 
+    if args.discrim == 'generic':
+        set_generic_model_params(args.discrim_weights, args.discrim_meta)
+
     # load pretrained model
     model = GPT2LMHeadModel.from_pretrained(
         args.model_path,
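
Usage note (illustrative, not part of the committed patch): with this change, a custom discriminator is selected by passing -D generic together with --discrim_weights (a file containing a trained ClassificationHead state dict loadable with torch.load) and --discrim_meta (a JSON file). set_generic_model_params() loads the JSON, stores the weights path under the "path" key, and registers the result as DISCRIMINATOR_MODELS_PARAMS['generic'], which get_classifier() then consumes. Judging by the keys get_classifier() reads, a minimal meta file would look roughly like the following sketch; all field values here are placeholder assumptions, not taken from the patch:

    {
        "class_size": 2,
        "embed_size": 1024,
        "class_vocab": {"negative": 0, "positive": 1},
        "default_class": 0
    }

A matching invocation, mirroring the docstring examples (the weights and meta file names here are hypothetical):

    python examples/run_pplm.py -D generic --discrim_weights my_discriminator.pt --discrim_meta my_discriminator_meta.json --class_label 1 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95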