fixes + clean up + mask is long
This commit is contained in:
parent
3ddff783c1
commit
d69b0b0e90
10
run_squad.py
10
run_squad.py
|
@ -24,8 +24,8 @@ import logging
|
|||
import json
|
||||
import math
|
||||
import os
|
||||
import six
|
||||
import random
|
||||
import six
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import numpy as np
|
||||
|
@ -750,7 +750,7 @@ def main():
|
|||
n_gpu = 1
|
||||
# Initializes the distributed backend which will take care of sychronizing nodes/GPUs
|
||||
torch.distributed.init_process_group(backend='nccl')
|
||||
print("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
|
||||
logger.info("device", device, "n_gpu", n_gpu, "distributed training", bool(args.local_rank != -1))
|
||||
|
||||
if args.accumulate_gradients < 1:
|
||||
raise ValueError("Invalid accumulate_gradients parameter: {}, should be >= 1".format(
|
||||
|
@ -855,7 +855,7 @@ def main():
|
|||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||
input_ids, input_mask, segment_ids, start_positions, end_positions = batch
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
start_positions = start_positions.to(device)
|
||||
end_positions = start_positions.to(device)
|
||||
|
@ -904,12 +904,12 @@ def main():
|
|||
model.eval()
|
||||
all_results = []
|
||||
logger.info("Start evaluating")
|
||||
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, descr="Evaluating"):
|
||||
for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
if len(all_results) % 1000 == 0:
|
||||
logger.info("Processing example: %d" % (len(all_results)))
|
||||
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.float().to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
|
||||
start_logits, end_logits = model(input_ids, segment_ids, input_mask)
|
||||
|
|
Loading…
Reference in New Issue