From f63ff536adfee0505a64d02649ef4b55dde614bb Mon Sep 17 00:00:00 2001
From: Rabeeh KARIMI
Date: Sat, 20 Jul 2019 12:43:07 +0200
Subject: [PATCH] fixed version issues in run_openai_gpt

---
 .../single_model_scripts/run_openai_gpt.py | 37 +++++++++++++------
 1 file changed, 25 insertions(+), 12 deletions(-)

diff --git a/examples/single_model_scripts/run_openai_gpt.py b/examples/single_model_scripts/run_openai_gpt.py
index b2e85271cb..af737b953e 100644
--- a/examples/single_model_scripts/run_openai_gpt.py
+++ b/examples/single_model_scripts/run_openai_gpt.py
@@ -40,7 +40,8 @@
 from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                               TensorDataset)
 from pytorch_transformers import (OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer,
-                                  AdamW, cached_path, WEIGHTS_NAME, CONFIG_NAME)
+                                  AdamW, cached_path, WEIGHTS_NAME, CONFIG_NAME,
+                                  WarmupLinearSchedule)
 
 ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
 
@@ -104,9 +105,18 @@ def main():
     parser.add_argument('--num_train_epochs', type=int, default=3)
     parser.add_argument('--train_batch_size', type=int, default=8)
     parser.add_argument('--eval_batch_size', type=int, default=16)
+    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
+                        help="Epsilon for Adam optimizer.")
     parser.add_argument('--max_grad_norm', type=int, default=1)
+    parser.add_argument("--max_steps", default=-1, type=int,
+                        help="If > 0: set total number of training \
+                        steps to perform. Override num_train_epochs.")
+    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
+                        help="Number of updates steps to accumulate before\
+                        performing a backward/update pass.")
     parser.add_argument('--learning_rate', type=float, default=6.25e-5)
-    parser.add_argument('--warmup_proportion', type=float, default=0.002)
+    parser.add_argument("--warmup_steps", default=0, type=int,
+                        help="Linear warmup over warmup_steps.")
     parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
     parser.add_argument('--weight_decay', type=float, default=0.01)
     parser.add_argument('--lm_coef', type=float, default=0.9)
@@ -184,19 +194,22 @@ def main():
 
     # Prepare optimizer
     if args.do_train:
+        if args.max_steps > 0:
+            t_total = args.max_steps
+            args.num_train_epochs = args.max_steps //\
+                (len(train_dataloader) // args.gradient_accumulation_steps) + 1
+        else:
+            t_total = len(train_dataloader)\
+                // args.gradient_accumulation_steps * args.num_train_epochs
+
         param_optimizer = list(model.named_parameters())
         no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
         optimizer_grouped_parameters = [
             {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
             {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
             ]
-        num_train_optimization_steps = len(train_dataloader) * args.num_train_epochs
-        optimizer = AdamW(optimizer_grouped_parameters,
-                          lr=args.learning_rate,
-                          warmup=args.warmup_proportion,
-                          max_grad_norm=args.max_grad_norm,
-                          weight_decay=args.weight_decay,
-                          t_total=num_train_optimization_steps)
+        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
 
     if args.do_train:
         nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
@@ -211,12 +224,13 @@ def main():
                 losses = model(input_ids, mc_token_ids, lm_labels, mc_labels)
                 loss = args.lm_coef * losses[0] + losses[1]
                 loss.backward()
+                scheduler.step()
                 optimizer.step()
                 optimizer.zero_grad()
                 tr_loss += loss.item()
                 exp_average_loss = loss.item() if exp_average_loss is None else 0.7*exp_average_loss+0.3*loss.item()
                 nb_tr_steps += 1
-                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, optimizer.get_lr()[0])
+                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, scheduler.get_lr()[0])
 
     # Save a trained model
     if args.do_train:
@@ -244,8 +258,7 @@ def main():
             batch = tuple(t.to(device) for t in batch)
             input_ids, mc_token_ids, lm_labels, mc_labels = batch
             with torch.no_grad():
-                _, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels)
-                _, mc_logits = model(input_ids, mc_token_ids)
+                _, mc_loss, _, mc_logits = model(input_ids, mc_token_ids, lm_labels, mc_labels)
 
             mc_logits = mc_logits.detach().cpu().numpy()
             mc_labels = mc_labels.to('cpu').numpy()
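
For reference, below is a minimal, self-contained sketch of the AdamW + WarmupLinearSchedule pattern this patch moves the script to: warmup, max-grad-norm and t_total are no longer passed to AdamW, so a separate schedule object owns the learning rate and is stepped once per batch, and the displayed lr is read from the scheduler. The toy model, synthetic dataloader and hyperparameter values are illustrative stand-ins rather than code from the script, and the optimizer-then-scheduler stepping order follows the usual PyTorch convention; it assumes a pytorch_transformers 1.0-era install.

    # Sketch only: a toy linear model and synthetic batches stand in for the real
    # OpenAIGPTDoubleHeadsModel and ROCStories dataloader used by run_openai_gpt.py.
    import torch
    from pytorch_transformers import AdamW, WarmupLinearSchedule

    model = torch.nn.Linear(10, 2)                                    # stand-in model
    train_dataloader = [(torch.randn(8, 10), torch.randint(2, (8,)))] * 100
    num_train_epochs, learning_rate, warmup_steps, adam_epsilon = 3, 6.25e-5, 0, 1e-8

    # Total number of optimization steps drives the linear decay after warmup.
    t_total = len(train_dataloader) * num_train_epochs
    optimizer = AdamW(model.parameters(), lr=learning_rate, eps=adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)

    for _ in range(num_train_epochs):
        for inputs, labels in train_dataloader:
            loss = torch.nn.functional.cross_entropy(model(inputs), labels)
            loss.backward()
            optimizer.step()       # update parameters
            scheduler.step()       # advance the warmup/linear-decay schedule
            optimizer.zero_grad()
            current_lr = scheduler.get_lr()[0]   # lr is now read from the scheduler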