This commit is contained in:
thomwolf 2018-11-04 15:17:55 +01:00
parent 834b485b2e
commit 6b0da96b4b
2 changed files with 15 additions and 12 deletions

View File

@ -69,7 +69,7 @@ class InputFeatures(object):
self.input_mask = input_mask
self.segment_ids = segment_ids
self.label_id = label_id
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
@ -95,8 +95,8 @@ class DataProcessor(object):
for line in reader:
lines.append(line)
return lines
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
@ -190,10 +190,9 @@ class ColaProcessor(DataProcessor):
examples.append(
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
return examples
def convert_examples_to_features(examples, label_list, max_seq_length,
tokenizer):
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
"""Loads a data file into a list of `InputBatch`s."""
label_map = {}
@ -380,7 +379,7 @@ def main():
parser.add_argument("--do_lower_case",
default=False,
action='store_true',
help="Whether to lower case the input text. Should be True for uncased models and False for cased models.")
help="Whether to lower case the input text. True for uncased models, False for cased models.")
parser.add_argument("--max_seq_length",
default=128,
type=int,
@ -424,6 +423,10 @@ def main():
default=False,
action='store_true',
help="Whether not to use CUDA when available")
parser.add_argument("--accumulate_gradients",
type=int,
default=1,
help="Number of steps to accumulate gradient on (divide the single step batch_size)")
parser.add_argument("--local_rank",
type=int,
default=-1,
@ -448,12 +451,12 @@ def main():
n_gpu = 1
# print("Initializing the distributed backend: NCCL")
print("device", device, "n_gpu", n_gpu)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if n_gpu>0: torch.cuda.manual_seed_all(args.seed)
if not args.do_train and not args.do_eval:
raise ValueError("At least one of `do_train` or `do_eval` must be True.")

View File

@ -18,15 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import argparse
import collections
import logging
import json
import math
import os
from tqdm import tqdm, trange
import six
import random
from tqdm import tqdm, trange
import numpy as np
import torch