fix run_squad example

This commit is contained in:
thomwolf 2019-02-11 14:06:32 +01:00
parent eebc8abbe2
commit af62cc5f20
1 changed files with 1 additions and 1 deletions

View File

@ -881,7 +881,7 @@ def main():
train_examples = read_squad_examples(
input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
num_train_optimization_steps = int(
len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
if args.local_rank != -1:
num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()