# Copyright 2022 - Intel Corp. All rights reserved.
# Authors: Mayank Kumar Raunak, Javier Turek, Nicole Beckage

"""
Implementation of a new method for fine-tuning transformer models that we call
Information Gain Filtration (IGF) on the WikiText data set, and comparison of the results
with the standard fine-tuning method.

Steps followed in the code:

1) Generate an objective dataset of pairs (X, IG(X)), where IG(X) is the informativeness of context X.
Our IG (information gain) model learns to predict the ‘informativeness’ of a particular
context. Informativeness is the change in metric between the model’s accuracy on an
objective set before and after seeing that context. For causal language modeling, the
metric is perplexity.

2) A secondary learner is trained to infer a function approximation for IG using the dataset
created in (1).

3) The learner created in (2) is used to inform the fine-tuning process and filter out samples with low informativeness.

Last, a plot is generated to compare the performance of IGF to standard fine-tuning without any filtering.

"""

# Prerequisite libraries:

import argparse
import random

import joblib
import numpy as np
import torch
from igf.igf import (
    SecondaryLearner,
    collect_objective_set,
    compute_perplexity,
    generate_datasets,
    load_gpt2,
    recopy_gpt2,
    set_seed,
    train_secondary_learner,
)
from torch.utils.data import DataLoader, RandomSampler

from transformers import GPT2LMHeadModel


def generate_n_pairs(
    context_len=32,
    max_steps=10,
    size_objective_set=100,
    min_len=1026,
    trim=True,
    data_file="data/tokenized_stories_train_wikitext103.jbl",
    igf_data_file="igf_context_pairs.jbl",
):
    """
    Collecting *n* pairs for training the secondary learner

    Args:
        context_len: The maximum total input sequence length after tokenization. Sequences longer
            than this will be truncated, sequences shorter will be padded
        max_steps: used to calculate the number of training epochs of the secondary learner
        size_objective_set: size of the objective data set used to create the (X, IG(X)) pairs, which are the training data for the secondary learner
        min_len: The minimum length of an article to be used in the objective set
        trim: If True, truncate the context if it exceeds context_len
        data_file: Tokenized data set split for training and evaluation of the model
        igf_data_file: file to store the (X, IG(X)) paired data set used to train the secondary learner

    Returns:
        Data stored in igf_data_file

    """
    # generates same data every time
    set_seed(3)
    # generate train_data and objective_set
    train_data, objective_set = generate_datasets(
        context_len, data_file, number=size_objective_set, min_len=min_len, trim=trim
    )
    # keeps model same across runs
    set_seed(4)
    # model, lm_optimizer, lm_scheduler = recopy_gpt2(model, device, max_steps) # store original model weights
    # can we train on GPU?
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # load pretrained model
    model = load_gpt2("openai-community/gpt2").to(device)
    print("computing perplexity on objective set")
    orig_perp = compute_perplexity(model, objective_set, context_len).item()
    print("perplexity on objective set:", orig_perp)

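    # collect_objective_set builds the (X, IG(X)) training pairs for the secondary learner:
    # for each context X drawn from train_data, IG(X) is the change in the model's perplexity
    # on objective_set after the model sees X (see the module docstring), measured relative to
    # orig_perp, and the resulting pairs are written to igf_data_file.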
    # collect igf pairs and save them to igf_data_file
    collect_objective_set(model, orig_perp, context_len, train_data, objective_set, max_steps, device, igf_data_file)

    # clean up, delete model and data we don't need anymore
    del model, train_data, objective_set
    torch.cuda.empty_cache()


def training_secondary_learner(
    secondary_learner_train_data,
    secondary_learner_max_epochs=15,
    secondary_learner_batch_size=128,
    eval_freq=100,
    igf_model_path="igf_model.pt",
):
    """
    Train the secondary learner

    Args:
        secondary_learner_train_data: Data set of (X, IG(X)) pairs to train the secondary learner, where X is a context and IG(X) its measure of informativeness
        secondary_learner_max_epochs: Number of epochs to train the secondary learner
        secondary_learner_batch_size: Batch size to train the secondary learner
        eval_freq: number of batches after which the secondary learner is evaluated
        igf_model_path: path to store the trained secondary learner

    Returns:
        Trained secondary learner
    """

    set_seed(42)

    # Load pre-trained model
    model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2")

    # Initialize secondary learner to use embedding weights of model
    secondary_learner = SecondaryLearner(model)

    # Train secondary learner
    secondary_learner = train_secondary_learner(
        secondary_learner,
        secondary_learner_train_data,
        max_epochs=secondary_learner_max_epochs,
        batch_size=secondary_learner_batch_size,
        eval_freq=eval_freq,
        igf_model_path=igf_model_path,
    )

    del model, secondary_learner_train_data
    torch.cuda.empty_cache()

    return secondary_learner


def finetune(
    model,
    train_dataset,
    test_dataset,
    context_len=32,
    max_steps=1000,
    batch_size=16,
    threshold=1.0,
    recopy_model=recopy_gpt2,
    secondary_learner=None,
    eval_interval=10,
    finetuned_model_name="openai-community/gpt2_finetuned.pt",
):
    """
    Fine-tune with IGF if secondary_learner is not None, else run standard fine-tuning

    Args:
        model: pre-trained GPT-2 model
        train_dataset: Data set to train the GPT-2 model
        test_dataset: Data set to evaluate the GPT-2 model
        context_len: The maximum total input sequence length after tokenization. Sequences longer
            than this will be truncated, sequences shorter will be padded
        max_steps: used to calculate the number of training epochs
        batch_size: Batch size to train the GPT-2 model
        threshold: The threshold value used by the secondary learner to filter train_data and allow only
            informative data as input to the model
        recopy_model: Reset the model to the original pretrained GPT-2 weights after each iteration
        secondary_learner: Selection of IGF as the fine-tuning method if not None
        eval_interval: number of batches after which the test perplexity of the model is re-evaluated
        finetuned_model_name: name of the final fine-tuned GPT-2 model

    Returns:
        Fine-tuned GPT-2 model

    """

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler)

    num_train_epochs = max_steps // len(train_dataset) + 1
    global_step = 0
    context = torch.zeros((1, context_len), dtype=torch.long, device=device)
    model, lm_optimizer, lm_scheduler = recopy_model(model, device, max_steps)

    model.train()
    if secondary_learner is not None:
        secondary_learner.to(device)
        secondary_learner.eval()
    contexts = []
    examples = 0

    observed_qs = []
    test_perps = []

    # Compute the performance of the transformer model at the beginning
    real_perp = compute_perplexity(model, test_dataset, context_len)
    test_perps.append(real_perp)
    print("Test perplexity, step", global_step, ":", real_perp)
    for epoch in range(int(num_train_epochs)):
        for step, example in enumerate(train_dataloader):
            torch.cuda.empty_cache()
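            # Each example from the dataloader holds one long tokenized sequence; pick a random
            # start offset and slice out a window of context_len tokens to use as the context X.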
            start = random.randint(0, example.size(2) - context_len - 1)
            context[0, :] = example[0, 0, start : start + context_len]
            lm_optimizer.zero_grad()
            outputs = model(context, labels=context)
            do_backprop = True

            if secondary_learner is not None:
                predicted_q = secondary_learner.forward(
                    context.clone().detach().unsqueeze(0)
                )[0].item()
                observed_qs.append(float(predicted_q))

                # Here we implement the simple non-constant threshold for the predicted IG(X) value
                # We will decay the selectivity of our secondary learner filter from
                # 1 standard deviation above average to 1 below average after 10 batches.

                if global_step == 10:
                    threshold = -1
                if predicted_q < threshold:
                    do_backprop = False

            # If we passed the filter, add the context to the batch!
            if do_backprop:
                contexts.append(np.array(context.cpu()))
                lm_loss = outputs[0]
                lm_loss.backward()
                examples += 1

            del outputs

            # Once the batch is filled with enough contexts, backprop on the batch.
            if examples == batch_size:
                torch.cuda.empty_cache()
                examples = 0
                # Do LM backprop
                torch.nn.utils.clip_grad_norm_(model.parameters(), 3.0)
                lm_optimizer.step()
                lm_scheduler.step()  # Update learning rate schedule
                global_step += 1
                # Compute the performance of the transformer model at this batch
                if global_step % eval_interval == 0:
                    real_perp = compute_perplexity(model, test_dataset, context_len)
                    test_perps.append(real_perp)

                    print("Test perplexity, step", global_step, ":", real_perp)
            # Break out of the loop after 60 batches
            if max_steps > 0 and global_step > 60:
                break
        if max_steps > 0 and global_step > 60:
            break

    # save finetuned transformer model
    torch.save(model.state_dict(), finetuned_model_name)
    torch.cuda.empty_cache()
    # Do some cleaning up so we can reinitialize for the next run of this function
    del lm_optimizer
    del lm_scheduler
    return model


def main():
    parser = argparse.ArgumentParser(description="Fine-tune a transformer model with IGF on a language modeling task")

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input data dir. Should contain data files for WikiText.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models",
    )
    parser.add_argument(
        "--data_file",
        type=str,
        default=None,
        help=(
            "A jbl file containing tokenized data which can be split as objective dataset, "
            "train_dataset and test_dataset."
        ),
    )

    parser.add_argument(
        "--igf_data_file",
        type=str,
        default=None,
        help="A jbl file containing the context and information gain pairs to train the secondary learner.",
    )

    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the final fine-tuned model is stored.",
    )

    parser.add_argument(
        "--tokenizer_name",
        default=None,
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")

    parser.add_argument(
        "--context_len",
        default=32,
        type=int,
        help=(
            "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        ),
    )

    parser.add_argument(
        "--size_objective_set",
        default=100,
        type=int,
        help="number of articles that are long enough to be used as our objective set",
    )
    parser.add_argument(
        "--eval_freq", default=100, type=int, help="secondary model evaluation is triggered at eval_freq"
    )

    parser.add_argument("--max_steps", default=1000, type=int, help="used to calculate the number of training epochs")

    parser.add_argument(
        "--secondary_learner_batch_size",
        default=128,
        type=int,
        help="batch size of training data for the secondary learner",
    )

    parser.add_argument(
        "--batch_size",
        default=16,
        type=int,
        help="batch size of training data for the language model (openai-community/gpt2)",
    )

    parser.add_argument(
        "--eval_interval",
        default=10,
        type=int,
        help=(
            "number of batches after which the test perplexity of the model is re-evaluated; "
            "the selectivity of the secondary learner filter is decayed from 1 standard deviation "
            "above average to 1 below average after 10 batches"
        ),
    )

    parser.add_argument(
        "--number", default=100, type=int, help="The number of examples split to be used as objective_set/test_data"
    )

    parser.add_argument(
        "--min_len", default=1026, type=int, help="The minimum length of an article to be used in the objective set"
    )

    parser.add_argument(
        "--secondary_learner_max_epochs", default=15, type=int, help="number of epochs to train the secondary learner"
    )

    parser.add_argument("--trim", default=True, type=bool, help="truncate the example if it exceeds context length")

    parser.add_argument(
        "--threshold",
        default=1.0,
        type=float,
        help=(
            "The threshold value used by the secondary learner to filter the train_data and allow only "
            "informative data as input to the model"
        ),
    )

    parser.add_argument(
        "--finetuned_model_name", default="openai-community/gpt2_finetuned.pt", type=str, help="finetuned_model_name"
    )

    parser.add_argument(
        "--recopy_model",
        default=recopy_gpt2,
        type=str,
        help="Reset the model to the original pretrained GPT-2 weights after each iteration",
    )

    args = parser.parse_args()

    # function calls
    # Collecting *n* pairs of context and information gain (X, IG(X)) for training the secondary learner
    generate_n_pairs(
        context_len=args.context_len,
        max_steps=10,
        size_objective_set=args.size_objective_set,
        min_len=args.min_len,
        trim=args.trim,
        data_file=args.data_file or "data/tokenized_stories_train_wikitext103.jbl",
        igf_data_file=args.igf_data_file or "igf_context_pairs.jbl",
    )

    # Load train data for secondary learner
    secondary_learner_train_data = joblib.load("data/IGF_values.jbl")

    # Train secondary learner
    secondary_learner = training_secondary_learner(
        secondary_learner_train_data,
        secondary_learner_max_epochs=args.secondary_learner_max_epochs,
        secondary_learner_batch_size=args.secondary_learner_batch_size,
        eval_freq=args.eval_freq,
        igf_model_path="igf_model.pt",
    )

    # load pretrained openai-community/gpt2 model
    model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path)
    set_seed(args.seed if args.seed is not None else 42)

    # Generate train and test data to train and evaluate the openai-community/gpt2 model
    train_dataset, test_dataset = generate_datasets(
        context_len=args.context_len,
        file=args.data_file or "data/tokenized_stories_train_wikitext103.jbl",
        number=args.number,
        min_len=args.min_len,
        trim=args.trim,
    )

    # fine-tuning of the openai-community/gpt2 model using IGF (Information Gain Filtration)
    finetune(
        model,
        train_dataset,
        test_dataset,
        context_len=args.context_len,
        max_steps=args.max_steps,
        batch_size=args.batch_size,
        threshold=args.threshold,
        recopy_model=recopy_gpt2,
        secondary_learner=secondary_learner,
        eval_interval=args.eval_interval,
        finetuned_model_name=args.finetuned_model_name,
    )


if __name__ == "__main__":
    main()