clean up
This commit is contained in:
parent
834b485b2e
commit
6b0da96b4b
|
@ -69,7 +69,7 @@ class InputFeatures(object):
|
|||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.label_id = label_id
|
||||
|
||||
|
||||
|
||||
class DataProcessor(object):
|
||||
"""Base class for data converters for sequence classification data sets."""
|
||||
|
@ -95,8 +95,8 @@ class DataProcessor(object):
|
|||
for line in reader:
|
||||
lines.append(line)
|
||||
return lines
|
||||
|
||||
|
||||
|
||||
|
||||
class MrpcProcessor(DataProcessor):
|
||||
"""Processor for the MRPC data set (GLUE version)."""
|
||||
|
||||
|
@ -190,10 +190,9 @@ class ColaProcessor(DataProcessor):
|
|||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
tokenizer):
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
label_map = {}
|
||||
|
@ -380,7 +379,7 @@ def main():
|
|||
parser.add_argument("--do_lower_case",
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Whether to lower case the input text. Should be True for uncased models and False for cased models.")
|
||||
help="Whether to lower case the input text. True for uncased models, False for cased models.")
|
||||
parser.add_argument("--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
|
@ -424,6 +423,10 @@ def main():
|
|||
default=False,
|
||||
action='store_true',
|
||||
help="Whether not to use CUDA when available")
|
||||
parser.add_argument("--accumulate_gradients",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of steps to accumulate gradient on (divide the single step batch_size)")
|
||||
parser.add_argument("--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
|
@ -448,12 +451,12 @@ def main():
|
|||
n_gpu = 1
|
||||
# print("Initializing the distributed backend: NCCL")
|
||||
print("device", device, "n_gpu", n_gpu)
|
||||
|
||||
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if n_gpu>0: torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
if not args.do_train and not args.do_eval:
|
||||
raise ValueError("At least one of `do_train` or `do_eval` must be True.")
|
||||
|
||||
|
|
|
@ -18,15 +18,15 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
import argparse
|
||||
import collections
|
||||
import logging
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
from tqdm import tqdm, trange
|
||||
import six
|
||||
import random
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
|
Loading…
Reference in New Issue