diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 63cc2b4b9c..7c00e4833d 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -222,6 +222,10 @@ def main(): elif n_gpu > 1: model = torch.nn.DataParallel(model) + global_step = 0 + nb_tr_steps = 0 + tr_loss = 0 + if args.do_train: if args.local_rank in [-1, 0]: tb_writer = SummaryWriter() @@ -293,10 +297,6 @@ def main(): warmup=args.warmup_proportion, t_total=num_train_optimization_steps) - global_step = 0 - nb_tr_steps = 0 - tr_loss = 0 - logger.info("***** Running training *****") logger.info(" Num examples = %d", len(train_examples)) logger.info(" Batch size = %d", args.train_batch_size)