Fixing the issues reported in https://github.com/huggingface/pytorch-pretrained-BERT/issues/556
Reason for issue was that optimzation steps where computed from example size, which is different from actual size of dataloader when an example is chunked into multiple instances. Solution in this pull request is to compute num_optimization_steps directly from len(data_loader).
This commit is contained in:
parent
3fc63f126d
commit
3bf3f9596f
|
@ -25,6 +25,7 @@ import random
|
|||
import sys
|
||||
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
|
@ -739,8 +740,25 @@ def main():
|
|||
num_train_optimization_steps = None
|
||||
if args.do_train:
|
||||
train_examples = processor.get_train_examples(args.data_dir)
|
||||
num_train_optimization_steps = int(
|
||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
|
||||
train_features = convert_examples_to_features(
|
||||
train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||
|
||||
if output_mode == "classification":
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)
|
||||
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_data)
|
||||
else:
|
||||
train_sampler = DistributedSampler(train_data)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
num_train_optimization_steps = len(train_dataloader) / args.gradient_accumulation_steps * args.num_train_epochs
|
||||
if args.local_rank != -1:
|
||||
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||
|
||||
|
@ -798,27 +816,10 @@ def main():
|
|||
nb_tr_steps = 0
|
||||
tr_loss = 0
|
||||
if args.do_train:
|
||||
train_features = convert_examples_to_features(
|
||||
train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_examples))
|
||||
logger.info(" Batch size = %d", args.train_batch_size)
|
||||
logger.info(" Num steps = %d", num_train_optimization_steps)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||
|
||||
if output_mode == "classification":
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)
|
||||
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_data)
|
||||
else:
|
||||
train_sampler = DistributedSampler(train_data)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
model.train()
|
||||
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||
|
|
|
@ -190,7 +190,7 @@ def main():
|
|||
{'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
|
||||
{'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
|
||||
]
|
||||
num_train_optimization_steps = len(train_data) * args.num_train_epochs // args.train_batch_size
|
||||
num_train_optimization_steps = len(train_dataloader) * args.num_train_epochs
|
||||
optimizer = OpenAIAdam(optimizer_grouped_parameters,
|
||||
lr=args.learning_rate,
|
||||
warmup=args.warmup_proportion,
|
||||
|
|
|
@ -899,8 +899,37 @@ def main():
|
|||
if args.do_train:
|
||||
train_examples = read_squad_examples(
|
||||
input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
|
||||
num_train_optimization_steps = int(
|
||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
|
||||
cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
|
||||
list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
|
||||
train_features = None
|
||||
try:
|
||||
with open(cached_train_features_file, "rb") as reader:
|
||||
train_features = pickle.load(reader)
|
||||
except:
|
||||
train_features = convert_examples_to_features(
|
||||
examples=train_examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=True)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving train features into cached file %s", cached_train_features_file)
|
||||
with open(cached_train_features_file, "wb") as writer:
|
||||
pickle.dump(train_features, writer)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_start_positions, all_end_positions)
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_data)
|
||||
else:
|
||||
train_sampler = DistributedSampler(train_data)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
num_train_optimization_steps = len(train_dataloader) / args.gradient_accumulation_steps * args.num_train_epochs
|
||||
if args.local_rank != -1:
|
||||
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||
|
||||
|
@ -960,41 +989,11 @@ def main():
|
|||
|
||||
global_step = 0
|
||||
if args.do_train:
|
||||
cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
|
||||
list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
|
||||
train_features = None
|
||||
try:
|
||||
with open(cached_train_features_file, "rb") as reader:
|
||||
train_features = pickle.load(reader)
|
||||
except:
|
||||
train_features = convert_examples_to_features(
|
||||
examples=train_examples,
|
||||
tokenizer=tokenizer,
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=True)
|
||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||
logger.info(" Saving train features into cached file %s", cached_train_features_file)
|
||||
with open(cached_train_features_file, "wb") as writer:
|
||||
pickle.dump(train_features, writer)
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num orig examples = %d", len(train_examples))
|
||||
logger.info(" Num split examples = %d", len(train_features))
|
||||
logger.info(" Batch size = %d", args.train_batch_size)
|
||||
logger.info(" Num steps = %d", num_train_optimization_steps)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||
all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_start_positions, all_end_positions)
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_data)
|
||||
else:
|
||||
train_sampler = DistributedSampler(train_data)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
model.train()
|
||||
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||
|
|
|
@ -362,8 +362,20 @@ def main():
|
|||
num_train_optimization_steps = None
|
||||
if args.do_train:
|
||||
train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
|
||||
num_train_optimization_steps = int(
|
||||
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
|
||||
train_features = convert_examples_to_features(
|
||||
train_examples, tokenizer, args.max_seq_length, True)
|
||||
all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
|
||||
all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
|
||||
all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
|
||||
all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_data)
|
||||
else:
|
||||
train_sampler = DistributedSampler(train_data)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
num_train_optimization_steps = len(train_dataloader) / args.gradient_accumulation_steps * args.num_train_epochs
|
||||
if args.local_rank != -1:
|
||||
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()
|
||||
|
||||
|
@ -422,22 +434,10 @@ def main():
|
|||
|
||||
global_step = 0
|
||||
if args.do_train:
|
||||
train_features = convert_examples_to_features(
|
||||
train_examples, tokenizer, args.max_seq_length, True)
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_examples))
|
||||
logger.info(" Batch size = %d", args.train_batch_size)
|
||||
logger.info(" Num steps = %d", num_train_optimization_steps)
|
||||
all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
|
||||
all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
|
||||
all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
|
||||
all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_data)
|
||||
else:
|
||||
train_sampler = DistributedSampler(train_data)
|
||||
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
model.train()
|
||||
for _ in trange(int(args.num_train_epochs), desc="Epoch"):
|
||||
|
|
Loading…
Reference in New Issue