From 0b51fba20bd88c4cc4acbb3e9dce82980719895c Mon Sep 17 00:00:00 2001
From: piero
Date: Tue, 26 Nov 2019 13:15:56 -0800
Subject: [PATCH] Added script for training a discriminator for PPLM to use

---
 examples/run_pplm.py               |  19 +-
 examples/run_pplm_discrim_train.py | 582 +++++++++++++++++++++++++++++
 2 files changed, 583 insertions(+), 18 deletions(-)
 create mode 100644 examples/run_pplm_discrim_train.py

diff --git a/examples/run_pplm.py b/examples/run_pplm.py
index 2f853d15c1..217c131b8f 100644
--- a/examples/run_pplm.py
+++ b/examples/run_pplm.py
@@ -34,6 +34,7 @@ import torch.nn.functional as F
 from torch.autograd import Variable
 from tqdm import trange
 
+from examples.run_pplm_discrim_train import ClassificationHead
 from transformers import GPT2Tokenizer
 from transformers.file_utils import cached_path
 from transformers.modeling_gpt2 import GPT2LMHeadModel
@@ -108,24 +109,6 @@ def top_k_filter(logits, k, probs=False):
                           logits)
 
 
-class ClassificationHead(torch.nn.Module):
-    """ Classification Head for the transformer """
-
-    def __init__(self, class_size=5, embed_size=2048):
-        super(ClassificationHead, self).__init__()
-        self.class_size = class_size
-        self.embed_size = embed_size
-        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
-        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
-        self.mlp = torch.nn.Linear(embed_size, class_size)
-
-    def forward(self, hidden_state):
-        # hidden_state = F.relu(self.mlp1(hidden_state))
-        # hidden_state = self.mlp2(hidden_state)
-        logits = self.mlp(hidden_state)
-        return logits
-
-
 def perturb_past(past, model, prev, args, classifier, good_index=None,
                  stepsize=0.01, vocab_size=50257,
                  original_probs=None, accumulated_hidden=None, true_past=None,
diff --git a/examples/run_pplm_discrim_train.py b/examples/run_pplm_discrim_train.py
new file mode 100644
index 0000000000..cc52234281
--- /dev/null
+++ b/examples/run_pplm_discrim_train.py
@@ -0,0 +1,582 @@
+#! /usr/bin/env python3
+# coding=utf-8

+# This code is licensed under a non-commercial license.

+import argparse
+import csv
+import json
+import math
+import time

+import numpy as np
+import torch
+import torch.nn.functional as F
+import torch.optim as optim
+import torch.utils.data as data
+from nltk.tokenize.treebank import TreebankWordDetokenizer
+from torchtext import data as torchtext_data
+from torchtext import datasets
+from transformers import GPT2Tokenizer, GPT2LMHeadModel

+torch.manual_seed(0)
+np.random.seed(0)
+EPSILON = 1e-10
+device = 'cpu'
+example_sentence = "This is incredible! I love it, this is the best chicken I have ever had."
+max_length_seq = 100


+class ClassificationHead(torch.nn.Module):
+    """Classification Head for transformer encoders"""

+    def __init__(self, class_size, embed_size):
+        super(ClassificationHead, self).__init__()
+        self.class_size = class_size
+        self.embed_size = embed_size
+        # self.mlp1 = torch.nn.Linear(embed_size, embed_size)
+        # self.mlp2 = (torch.nn.Linear(embed_size, class_size))
+        self.mlp = torch.nn.Linear(embed_size, class_size)

+    def forward(self, hidden_state):
+        # hidden_state = F.relu(self.mlp1(hidden_state))
+        # hidden_state = self.mlp2(hidden_state)
+        logits = self.mlp(hidden_state)
+        return logits


+class Discriminator(torch.nn.Module):
+    """Transformer encoder followed by a Classification Head"""

+    def __init__(
+            self,
+            class_size,
+            pretrained_model="gpt2-medium",
+            cached_mode=False
+    ):
+        super(Discriminator, self).__init__()
+        self.tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model)
+        self.encoder = GPT2LMHeadModel.from_pretrained(pretrained_model)
+        self.embed_size = self.encoder.transformer.config.hidden_size
+        self.classifier_head = ClassificationHead(
+            class_size=class_size,
+            embed_size=self.embed_size
+        )
+        self.cached_mode = cached_mode

+    def get_classifier(self):
+        return self.classifier_head

+    def train_custom(self):
+        # Freeze the GPT-2 encoder; only the classification head is trained.
+        for param in self.encoder.parameters():
+            param.requires_grad = False
+        self.classifier_head.train()

+    def avg_representation(self, x):
+        # Mask out padding (index 0) and mean-pool the hidden states
+        # over the sequence length.
+        mask = x.ne(0).unsqueeze(2).repeat(
+            1, 1, self.embed_size
+        ).float().to(device).detach()
+        hidden, _ = self.encoder.transformer(x)
+        masked_hidden = hidden * mask
+        avg_hidden = torch.sum(masked_hidden, dim=1) / (
+                torch.sum(mask, dim=1).detach() + EPSILON
+        )
+        return avg_hidden

+    def forward(self, x):
+        if self.cached_mode:
+            avg_hidden = x.to(device)
+        else:
+            avg_hidden = self.avg_representation(x)

+        logits = self.classifier_head(avg_hidden)
+        probs = F.log_softmax(logits, dim=-1)

+        return probs


+class Dataset(data.Dataset):
+    def __init__(self, X, y):
+        """Stores the encoded sequences and their labels."""
+        self.X = X
+        self.y = y

+    def __len__(self):
+        return len(self.X)

+    def __getitem__(self, index):
+        """Returns one data pair (sequence and label)."""
+        data = {}
+        data['X'] = self.X[index]
+        data['y'] = self.y[index]
+        return data


+def collate_fn(data):
+    def pad_sequences(sequences):
+        lengths = [len(seq) for seq in sequences]

+        padded_sequences = torch.zeros(
+            len(sequences),
+            max(lengths)
+        ).long()  # padding index 0

+        for i, seq in enumerate(sequences):
+            end = lengths[i]
+            padded_sequences[i, :end] = seq[:end]

+        return padded_sequences, lengths

+    item_info = {}
+    for key in data[0].keys():
+        item_info[key] = [d[key] for d in data]

+    x_batch, _ = pad_sequences(item_info['X'])
+    y_batch = torch.tensor(item_info['y'], dtype=torch.long)

+    return x_batch, y_batch


+def cached_collate_fn(data):
+    item_info = {}
+    for key in data[0].keys():
+        item_info[key] = [d[key] for d in data]

+    x_batch = torch.cat(item_info['X'], 0)
+    y_batch = torch.tensor(item_info['y'], dtype=torch.long)

+    return x_batch, y_batch


+def train_epoch(data_loader, discriminator, optimizer,
+                epoch=0, log_interval=10):
+    samples_so_far = 0
+    discriminator.train_custom()
+    for batch_idx, (input_t, target_t) in enumerate(data_loader):
+        input_t, target_t = input_t.to(device), target_t.to(device)

+        optimizer.zero_grad()

+        output_t = discriminator(input_t)
+        loss = F.nll_loss(output_t, target_t)
+        loss.backward(retain_graph=True)
+        optimizer.step()

+        samples_so_far += len(input_t)
+
+        if batch_idx % log_interval == 0:
+            print(
+                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                    epoch + 1,
+                    samples_so_far, len(data_loader.dataset),
+                    100 * samples_so_far / len(data_loader.dataset), loss.item()
+                )
+            )


+def evaluate_performance(data_loader, discriminator):
+    discriminator.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for input_t, target_t in data_loader:
+            input_t, target_t = input_t.to(device), target_t.to(device)
+            output_t = discriminator(input_t)
+            # sum up batch loss
+            test_loss += F.nll_loss(output_t, target_t, reduction='sum').item()
+            # get the index of the max log-probability
+            pred_t = output_t.argmax(dim=1, keepdim=True)
+            correct += pred_t.eq(target_t.view_as(pred_t)).sum().item()

+    test_loss /= len(data_loader.dataset)

+    print(
+        'Performance on test set: '
+        'Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
+            test_loss, correct, len(data_loader.dataset),
+            100. * correct / len(data_loader.dataset)
+        )
+    )


+def predict(input_sentence, model, classes, cached=False):
+    input_t = model.tokenizer.encode(input_sentence)
+    input_t = torch.tensor([input_t], dtype=torch.long, device=device)
+    if cached:
+        input_t = model.avg_representation(input_t)

+    log_probs = model(input_t).data.cpu().numpy().flatten().tolist()
+    print('Input sentence:', input_sentence)
+    print('Predictions:', ", ".join(
+        "{}: {:.4f}".format(c, math.exp(log_prob)) for c, log_prob in
+        zip(classes, log_probs)
+    ))


+def get_cached_data_loader(dataset, batch_size, discriminator, shuffle=False):
+    data_loader = torch.utils.data.DataLoader(dataset=dataset,
+                                              batch_size=batch_size,
+                                              collate_fn=collate_fn)

+    xs = []
+    ys = []
+    for batch_idx, (x, y) in enumerate(data_loader):
+        with torch.no_grad():
+            x = x.to(device)
+            avg_rep = discriminator.avg_representation(x).cpu().detach()
+            avg_rep_list = torch.unbind(avg_rep.unsqueeze(1))
+            xs += avg_rep_list
+            ys += y.cpu().numpy().tolist()

+    data_loader = torch.utils.data.DataLoader(
+        dataset=Dataset(xs, ys),
+        batch_size=batch_size,
+        shuffle=shuffle,
+        collate_fn=cached_collate_fn)

+    return data_loader


+def train_discriminator(
+        dataset, dataset_fp=None, pretrained_model='gpt2-medium',
+        epochs=10, batch_size=64, log_interval=10,
+        save_model=False, cached=False, use_cuda=False):
+    if use_cuda:
+        global device
+        device = 'cuda'

+    print('Preprocessing {} dataset...'.format(dataset))
+    start = time.time()

+    if dataset == 'SST':
+        idx2class = ["positive", "negative", "very positive", "very negative",
+                     "neutral"]
+        class2idx = {c: i for i, c in enumerate(idx2class)}

+        discriminator = Discriminator(
+            class_size=len(idx2class),
+            pretrained_model=pretrained_model,
+            cached_mode=cached
+        ).to(device)

+        text = torchtext_data.Field()
+        label = torchtext_data.Field(sequential=False)
+        train_data, val_data, test_data = datasets.SST.splits(
+            text,
+            label,
+            fine_grained=True,
+            train_subtrees=True,
+        )

+        x = []
+        y = []
+        for i in range(len(train_data)):
+            seq = TreebankWordDetokenizer().detokenize(
+                vars(train_data[i])["text"]
+            )
+            seq = discriminator.tokenizer.encode(seq)
+            # 50256 is the GPT-2 <|endoftext|> token, prepended as a BOS marker.
+            seq = torch.tensor([50256] + seq, device=device, dtype=torch.long)
+            x.append(seq)
+            y.append(class2idx[vars(train_data[i])["label"]])
+        train_dataset = Dataset(x, y)

+        test_x = []
+        test_y = []
+        for i in range(len(test_data)):
+            seq = TreebankWordDetokenizer().detokenize(
+                vars(test_data[i])["text"]
+            )
+            seq = discriminator.tokenizer.encode(seq)
+            seq = torch.tensor([50256] + seq,
+                               device=device, dtype=torch.long)
+            test_x.append(seq)
+            test_y.append(class2idx[vars(test_data[i])["label"]])
+        test_dataset = Dataset(test_x, test_y)

+        discriminator_meta = {
+            "class_size": len(idx2class),
+            "embed_size": discriminator.embed_size,
+            "pretrained_model": pretrained_model,
+            "class_vocab": class2idx,
+            "default_class": 2,
+        }

+    elif dataset == 'clickbait':
+        idx2class = ["non_clickbait", "clickbait"]
+        class2idx = {c: i for i, c in enumerate(idx2class)}

+        discriminator = Discriminator(
+            class_size=len(idx2class),
+            pretrained_model=pretrained_model,
+            cached_mode=cached
+        ).to(device)

+        with open("datasets/clickbait/clickbait_train_prefix.txt") as f:
+            data = []
+            for i, line in enumerate(f):
+                # Each line is expected to be a dict literal with
+                # "text" and "label" fields.
+                try:
+                    data.append(eval(line))
+                except Exception:
+                    print('Error evaluating line {}: {}'.format(i, line))
+                    continue

+        x = []
+        y = []
+        for i, d in enumerate(data):
+            try:
+                seq = discriminator.tokenizer.encode(d["text"])

+                if len(seq) < max_length_seq:
+                    seq = torch.tensor(
+                        [50256] + seq, device=device, dtype=torch.long
+                    )
+                else:
+                    print("Line {} is longer than maximum length {}".format(
+                        i, max_length_seq
+                    ))
+                    continue
+                x.append(seq)
+                y.append(d['label'])
+            except Exception:
+                print("Error tokenizing line {}, skipping it".format(i))

+        full_dataset = Dataset(x, y)
+        train_size = int(0.9 * len(full_dataset))
+        test_size = len(full_dataset) - train_size
+        train_dataset, test_dataset = torch.utils.data.random_split(
+            full_dataset, [train_size, test_size]
+        )

+        discriminator_meta = {
+            "class_size": len(idx2class),
+            "embed_size": discriminator.embed_size,
+            "pretrained_model": pretrained_model,
+            "class_vocab": class2idx,
+            "default_class": 1,
+        }

+    elif dataset == 'toxic':
+        idx2class = ["non_toxic", "toxic"]
+        class2idx = {c: i for i, c in enumerate(idx2class)}

+        discriminator = Discriminator(
+            class_size=len(idx2class),
+            pretrained_model=pretrained_model,
+            cached_mode=cached
+        ).to(device)

+        with open("datasets/toxic/toxic_train.txt") as f:
+            data = []
+            for i, line in enumerate(f):
+                try:
+                    data.append(eval(line))
+                except Exception:
+                    print('Error evaluating line {}: {}'.format(i, line))
+                    continue

+        x = []
+        y = []
+        for i, d in enumerate(data):
+            try:
+                seq = discriminator.tokenizer.encode(d["text"])

+                if len(seq) < max_length_seq:
+                    seq = torch.tensor(
+                        [50256] + seq, device=device, dtype=torch.long
+                    )
+                else:
+                    print("Line {} is longer than maximum length {}".format(
+                        i, max_length_seq
+                    ))
+                    continue
+                x.append(seq)
+                y.append(int(np.sum(d['label']) > 0))
+            except Exception:
+                print("Error tokenizing line {}, skipping it".format(i))

+        full_dataset = Dataset(x, y)
+        train_size = int(0.9 * len(full_dataset))
+        test_size = len(full_dataset) - train_size
+        train_dataset, test_dataset = torch.utils.data.random_split(
+            full_dataset, [train_size, test_size]
+        )

+        discriminator_meta = {
+            "class_size": len(idx2class),
+            "embed_size": discriminator.embed_size,
+            "pretrained_model": pretrained_model,
+            "class_vocab": class2idx,
+            "default_class": 0,
+        }

+    else:  # if dataset == 'generic':
+        # This assumes the input dataset is a TSV with the following structure:
+        # class \t text

+        if dataset_fp is None:
+            raise ValueError('When generic dataset is selected, '
+                             'dataset_fp needs to be specified as well.')

+        classes = set()
+        with open(dataset_fp) as f:
+            csv_reader = csv.reader(f, delimiter='\t')
+            for row in csv_reader:
+                classes.add(row[0])

+        idx2class = sorted(classes)
+        class2idx = {c: i for i, c in enumerate(idx2class)}

+        discriminator = Discriminator(
+            class_size=len(idx2class),
+            pretrained_model=pretrained_model,
+            cached_mode=cached
+        ).to(device)

+        x = []
+        y = []
+        with open(dataset_fp) as f:
+            csv_reader = csv.reader(f, delimiter='\t')
+            for i, row in enumerate(csv_reader):
+                label = row[0]
+                text = row[1]

+                try:
+                    seq = discriminator.tokenizer.encode(text)
+                    if len(seq) < max_length_seq:
+                        seq = torch.tensor(
+                            [50256] + seq,
+                            device=device,
+                            dtype=torch.long
+                        )
+                    else:
+                        print("Line {} is longer than maximum length {}".format(
+                            i, max_length_seq
+                        ))
+                        continue

+                    x.append(seq)
+                    y.append(class2idx[label])

+                except Exception:
+                    print("Error tokenizing line {}, skipping it".format(i))

+        full_dataset = Dataset(x, y)
+        train_size = int(0.9 * len(full_dataset))
+        test_size = len(full_dataset) - train_size
+        train_dataset, test_dataset = torch.utils.data.random_split(
+            full_dataset,
+            [train_size, test_size]
+        )

+        discriminator_meta = {
+            "class_size": len(idx2class),
+            "embed_size": discriminator.embed_size,
+            "pretrained_model": pretrained_model,
+            "class_vocab": class2idx,
+            "default_class": 0,
+        }

+    end = time.time()
+    print('Preprocessed {} data points'.format(
+        len(train_dataset) + len(test_dataset))
+    )
+    print("Data preprocessing took: {:.3f}s".format(end - start))

+    if cached:
+        start = time.time()

+        train_loader = get_cached_data_loader(
+            train_dataset, batch_size, discriminator, shuffle=True
+        )

+        test_loader = get_cached_data_loader(
+            test_dataset, batch_size, discriminator
+        )

+        end = time.time()
+        print("Building representation cache took: {:.3f}s".format(end - start))

+    else:
+        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
+                                                   batch_size=batch_size,
+                                                   shuffle=True,
+                                                   collate_fn=collate_fn)
+        test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
+                                                  batch_size=batch_size,
+                                                  collate_fn=collate_fn)

+    if save_model:
+        with open("{}_classifier_head_meta.json".format(dataset),
+                  "w") as meta_file:
+            json.dump(discriminator_meta, meta_file)

+    optimizer = optim.Adam(discriminator.parameters(), lr=0.0001)

+    for epoch in range(epochs):
+        start = time.time()
+        print('\nEpoch', epoch + 1)

+        train_epoch(
+            discriminator=discriminator,
+            data_loader=train_loader,
+            optimizer=optimizer,
+            epoch=epoch,
+            log_interval=log_interval
+        )
+        evaluate_performance(
+            data_loader=test_loader,
+            discriminator=discriminator
+        )

+        end = time.time()
+        print("Epoch took: {:.3f}s".format(end - start))

+        print("\nExample prediction")
+        predict(example_sentence, discriminator, idx2class, cached)

+        if save_model:
+            # torch.save(discriminator.state_dict(),
+            #            "{}_discriminator_{}.pt".format(
+            #                args.dataset, epoch
+            #            ))
+            torch.save(discriminator.get_classifier().state_dict(),
+                       "{}_classifier_head_epoch_{}.pt".format(dataset, epoch))


+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(
+        description='Train a discriminator on top of GPT-2 representations')
+    parser.add_argument('--dataset', type=str, default='SST',
+                        choices=('SST', 'clickbait', 'toxic', 'generic'),
+                        help='dataset to train the discriminator on. '
+                             'In case of generic, the dataset is expected '
+                             'to be a TSV file with structure: class \\t text')
+    parser.add_argument('--dataset_fp', type=str, default='',
+                        help='File path of the dataset to use. '
+                             'Needed only in case of generic dataset')
+    parser.add_argument('--pretrained_model', type=str, default='gpt2-medium',
+                        help='Pretrained model to use as encoder')
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='Number of training epochs')
+    parser.add_argument('--batch_size', type=int, default=64, metavar='N',
+                        help='Input batch size for training (default: 64)')
+    parser.add_argument('--log_interval', type=int, default=10, metavar='N',
+                        help='How many batches to wait before logging training status')
+    parser.add_argument('--save_model', action='store_true',
+                        help='Whether to save the model')
+    parser.add_argument('--cached', action='store_true',
+                        help='Whether to cache the input representations')
+    parser.add_argument('--use_cuda', action='store_true',
+                        help='Whether to use CUDA')
+    args = parser.parse_args()

+    train_discriminator(**vars(args))
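+
+# Example usage (a sketch; the generic dataset path below is a placeholder,
+# the flags are the ones defined above):
+#
+#   python examples/run_pplm_discrim_train.py --dataset SST --save_model
+#   python examples/run_pplm_discrim_train.py --dataset generic \
+#       --dataset_fp datasets/my_dataset.tsv --cached --save_model
+#
+# With --save_model, each epoch writes {dataset}_classifier_head_epoch_{epoch}.pt
+# and a {dataset}_classifier_head_meta.json describing the head, which are
+# intended for use by run_pplm.py together with the ClassificationHead defined
+# in this module.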