Added support for generic discriminators
parent b0eaff36e6
commit 7fd54b55a3
@@ -14,17 +14,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# TODO: add code for training a custom discriminator
-
 """
 Example command with bag of words:
 python examples/run_pplm.py -B space --cond_text "The president" --length 100 --gamma 1.5 --num_iterations 3 --num_samples 10 --stepsize 0.01 --window_length 5 --kl_scale 0.01 --gm_scale 0.95
 
 Example command with discriminator:
-python examples/run_pplm.py -D sentiment --label_class 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
+python examples/run_pplm.py -D sentiment --class_label 3 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95
 """
 
 import argparse
+import json
 from operator import add
 from typing import List, Optional, Tuple, Union
 
@@ -121,7 +120,7 @@ def perturb_past(
     grad_norms=None,
     stepsize=0.01,
     classifier=None,
-    label_class=None,
+    class_label=None,
     one_hot_bows_vectors=None,
     loss_type=0,
     num_iterations=3,
@@ -230,7 +229,7 @@ def perturb_past(
             prediction = classifier(new_accumulated_hidden /
                                     (curr_length + 1 + horizon_length))
 
-            label = torch.tensor([label_class], device=device,
+            label = torch.tensor([class_label], device=device,
                                  dtype=torch.long)
             discrim_loss = ce_loss(prediction, label)
             print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())
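
Note on this hunk: the discriminator loss is a cross-entropy between the classifier's
prediction on the averaged hidden state and the (renamed) class_label. Below is a
minimal, self-contained sketch of that computation, with a single torch.nn.Linear
standing in for the repo's ClassificationHead and all sizes invented:

import torch

embed_size, class_size = 1024, 5                      # hypothetical sizes
hidden = torch.randn(1, 20, embed_size)               # [batch, seq_len, hidden]
classifier = torch.nn.Linear(embed_size, class_size)  # stand-in for ClassificationHead

prediction = classifier(hidden.mean(dim=1))           # average hidden state over positions
label = torch.tensor([3], dtype=torch.long)           # e.g. class_label 3
discrim_loss = torch.nn.CrossEntropyLoss()(prediction, label)
print(" pplm_discrim_loss:", discrim_loss.data.cpu().numpy())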
@@ -244,7 +243,8 @@ def perturb_past(
                 unpert_probs + SMALL_CONST *
                 (unpert_probs <= SMALL_CONST).float().to(device).detach()
             )
-            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(device).detach()
+            correction = SMALL_CONST * (probs <= SMALL_CONST).float().to(
+                device).detach()
             corrected_probs = probs + correction.detach()
             kl_loss = kl_scale * (
                 (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
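
The SMALL_CONST correction in this hunk (only rewrapped, behavior unchanged) guards
the KL term against zero probabilities, where log(0) or 0/0 would produce nan. A toy
illustration, with the SMALL_CONST value and the distributions invented for the example:

import torch

SMALL_CONST = 1e-15
kl_scale = 0.01
probs = torch.tensor([0.7, 0.3, 0.0])
unpert_probs = torch.tensor([0.6, 0.4, 0.0])

unpert_probs = unpert_probs + SMALL_CONST * (unpert_probs <= SMALL_CONST).float()
correction = SMALL_CONST * (probs <= SMALL_CONST).float()
corrected_probs = probs + correction
kl_loss = kl_scale * (corrected_probs * (corrected_probs / unpert_probs).log()).sum()
print(kl_loss)  # finite, where the uncorrected version would be nan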
@@ -273,7 +273,8 @@ def perturb_past(
         # normalize gradients
         grad = [
             -stepsize *
-            (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
+            (p_.grad * window_mask / grad_norms[
+                index] ** gamma).data.cpu().numpy()
             for index, p_ in enumerate(curr_perturbation)
         ]
 
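
For context on this hunk (again only a line wrap): each perturbation tensor's gradient
is masked to a window, divided by its own norm raised to gamma, and scaled by
-stepsize, giving a gradient-descent update on the combined loss. A toy sketch with
invented shapes and a dummy loss:

import torch

stepsize, gamma = 0.01, 1.5
curr_perturbation = [torch.randn(2, 3, requires_grad=True)]
sum((p_ ** 2).sum() for p_ in curr_perturbation).backward()  # dummy loss

window_mask = torch.ones(2, 3)  # no windowing in this toy case
grad_norms = [torch.norm(p_.grad * window_mask) + 1e-15 for p_ in curr_perturbation]
grad = [
    -stepsize *
    (p_.grad * window_mask / grad_norms[index] ** gamma).data.cpu().numpy()
    for index, p_ in enumerate(curr_perturbation)
]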
@@ -301,7 +302,7 @@ def perturb_past(
 
 
 def get_classifier(
-        name: Optional[str], label_class: Union[str, int],
+        name: Optional[str], class_label: Union[str, int],
         device: str
 ) -> Tuple[Optional[ClassificationHead], Optional[int]]:
     if name is None:
@@ -312,26 +313,29 @@ def get_classifier(
         class_size=params['class_size'],
         embed_size=params['embed_size']
     ).to(device)
-    resolved_archive_file = cached_path(params["url"])
+    if "url" in params:
+        resolved_archive_file = cached_path(params["url"])
+    else:
+        resolved_archive_file = params["path"]
     classifier.load_state_dict(
         torch.load(resolved_archive_file, map_location=device))
     classifier.eval()
 
-    if isinstance(label_class, str):
-        if label_class in params["class_vocab"]:
-            label_id = params["class_vocab"][label_class]
+    if isinstance(class_label, str):
+        if class_label in params["class_vocab"]:
+            label_id = params["class_vocab"][class_label]
         else:
             label_id = params["default_class"]
-            print("label_class {} not in class_vocab".format(label_class))
+            print("class_label {} not in class_vocab".format(class_label))
             print("available values are: {}".format(params["class_vocab"]))
             print("using default class {}".format(label_id))
 
-    elif isinstance(label_class, int):
-        if label_class in set(params["class_vocab"].values()):
-            label_id = label_class
+    elif isinstance(class_label, int):
+        if class_label in set(params["class_vocab"].values()):
+            label_id = class_label
         else:
             label_id = params["default_class"]
-            print("label_class {} not in class_vocab".format(label_class))
+            print("class_label {} not in class_vocab".format(class_label))
             print("available values are: {}".format(params["class_vocab"]))
             print("using default class {}".format(label_id))
 
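
With the "url"/"path" branch above, the params dict for a generic discriminator can
point at local weights instead of a download. A hypothetical example of the keys this
function reads (all values invented):

example_meta = {
    "class_size": 2,                                # number of output classes
    "embed_size": 1024,                             # hidden size the head was trained on
    "class_vocab": {"negative": 0, "positive": 1},  # maps string labels to ids
    "default_class": 1,                             # fallback when class_label is unknown
    "path": "my_discriminator.pt",                  # local weights; "url" also still works
}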
@@ -379,7 +383,7 @@ def full_text_generation(
         device="cuda",
         sample=True,
         discrim=None,
-        label_class=None,
+        class_label=None,
         bag_of_words=None,
         length=100,
         grad_length=10000,
@@ -397,7 +401,7 @@ def full_text_generation(
 ):
     classifier, class_id = get_classifier(
         discrim,
-        label_class,
+        class_label,
         device
     )
 
@@ -443,7 +447,7 @@ def full_text_generation(
         perturb=True,
         bow_indices=bow_indices,
         classifier=classifier,
-        label_class=class_id,
+        class_label=class_id,
         loss_type=loss_type,
         length=length,
         grad_length=grad_length,
@@ -477,7 +481,7 @@ def generate_text_pplm(
         sample=True,
         perturb=True,
         classifier=None,
-        label_class=None,
+        class_label=None,
         bow_indices=None,
         loss_type=0,
         length=100,
@@ -545,7 +549,7 @@ def generate_text_pplm(
                 grad_norms=grad_norms,
                 stepsize=current_stepsize,
                 classifier=classifier,
-                label_class=label_class,
+                class_label=class_label,
                 one_hot_bows_vectors=one_hot_bows_vectors,
                 loss_type=loss_type,
                 num_iterations=num_iterations,
@@ -567,7 +571,7 @@ def generate_text_pplm(
         if classifier is not None:
             ce_loss = torch.nn.CrossEntropyLoss()
             prediction = classifier(torch.mean(unpert_last_hidden, dim=1))
-            label = torch.tensor([label_class], device=device,
+            label = torch.tensor([class_label], device=device,
                                  dtype=torch.long)
             unpert_discrim_loss = ce_loss(prediction, label)
             print(
@@ -613,6 +617,20 @@ def generate_text_pplm(
     return output_so_far, unpert_discrim_loss, loss_in_time
 
 
+def set_generic_model_params(discrim_weights, discrim_meta):
+    if discrim_weights is None:
+        raise ValueError('When using a generic discriminator, '
+                         'discrim_weights need to be specified')
+    if discrim_meta is None:
+        raise ValueError('When using a generic discriminator, '
+                         'discrim_meta need to be specified')
+
+    with open(discrim_meta, 'r') as discrim_meta_file:
+        meta = json.load(discrim_meta_file)
+    meta['path'] = discrim_weights
+    DISCRIMINATOR_MODELS_PARAMS['generic'] = meta
+
+
 def run_model():
     parser = argparse.ArgumentParser()
     parser.add_argument(
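
A sketch of how the new set_generic_model_params function might be exercised directly
(file names and meta contents are hypothetical; run_model() normally invokes it from
the CLI flags added below):

import json

meta = {"class_size": 2, "embed_size": 1024,
        "class_vocab": {"negative": 0, "positive": 1},
        "default_class": 1}
with open("my_discrim_meta.json", "w") as f:
    json.dump(meta, f)

set_generic_model_params("my_discriminator.pt", "my_discrim_meta.json")
# DISCRIMINATOR_MODELS_PARAMS['generic'] now holds meta plus a 'path' entry,
# so get_classifier('generic', class_label, device) can load the local weights.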
@@ -636,11 +654,15 @@ def run_model():
         "-D",
         type=str,
         default=None,
-        choices=("clickbait", "sentiment", "toxicity"),
-        help="Discriminator to use for loss-type 2",
+        choices=("clickbait", "sentiment", "toxicity", "generic"),
+        help="Discriminator to use",
     )
+    parser.add_argument('--discrim_weights', type=str, default=None,
+                        help='Weights for the generic discriminator')
+    parser.add_argument('--discrim_meta', type=str, default=None,
+                        help='Meta information for the generic discriminator')
     parser.add_argument(
-        "--label_class",
+        "--class_label",
         type=int,
         default=-1,
         help="Class label used for the discriminator",
@@ -697,6 +719,9 @@ def run_model():
     # set the device
     device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
 
+    if args.discrim == 'generic':
+        set_generic_model_params(args.discrim_weights, args.discrim_meta)
+
     # load pretrained model
     model = GPT2LMHeadModel.from_pretrained(
         args.model_path,
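
Putting the pieces together, a run with a generic discriminator would look something
like this (weights/meta paths, label, and hyperparameters are illustrative, in the
style of the docstring examples above):

python examples/run_pplm.py -D generic --discrim_weights my_discriminator.pt --discrim_meta my_discrim_meta.json --class_label 1 --cond_text "The lake" --length 10 --gamma 1.0 --num_iterations 30 --num_samples 10 --stepsize 0.01 --kl_scale 0.01 --gm_scale 0.95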