fixes + clean up + mask is long

This commit is contained in:
thomwolf 2018-11-04 21:26:54 +01:00
parent 3ddff783c1
commit d69b0b0e90
1 changed files with 5 additions and 5 deletions

View File

@ -24,8 +24,8 @@ import logging
import json
import math
import os
import six
import random
import six
from tqdm import tqdm, trange
import numpy as np
@ -750,7 +750,7 @@ def main():
n_gpu = 1
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
torch.distributed.init_process_group(backend='nccl')
print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
if args.accumulate_gradients < 1:
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
@ -855,7 +855,7 @@ def main():
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
input_ids, input_mask, segment_ids, start_positions, end_positions = batch
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
start_positions = start_positions.to(device)
end_positions = start_positions.to(device)
@ -904,12 +904,12 @@ def main():
model.eval()
all_results = []
logger.info("Start evaluating")
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, descr="Evaluating"):
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, desc="Evaluating"):
if len(all_results) % 1000 == 0:
logger.info("Processing example: %d" % (len(all_results)))
input_ids = input_ids.to(device)
input_mask = input_mask.float().to(device)
input_mask = input_mask.to(device)
segment_ids = segment_ids.to(device)
start_logits, end_logits = model(input_ids, segment_ids, input_mask)