Merge pull request #388 from ananyahjha93/master
Added remaining GLUE tasks to 'run_classifier.py'
This commit is contained in:
commit
694e2117f3
51
README.md
51
README.md
|
@ -927,11 +927,60 @@ Where `$THIS_MACHINE_INDEX` is an sequential index assigned to each of your mach
|
|||
|
||||
We showcase several fine-tuning examples based on (and extended from) [the original implementation](https://github.com/google-research/bert/):
|
||||
|
||||
- a *sequence-level classifier* on the MRPC classification corpus,
|
||||
- a *sequence-level classifier* on nine different GLUE tasks,
|
||||
- a *token-level classifier* on the question answering dataset SQuAD, and
|
||||
- a *sequence-level multiple-choice classifier* on the SWAG classification corpus.
|
||||
- a *BERT language model* on another target corpus
|
||||
|
||||
#### GLUE results on dev set
|
||||
|
||||
We get the following results on the dev set of GLUE benchmark with an uncased BERT base
|
||||
model. All experiments were run on a P100 GPU with a batch size of 32.
|
||||
|
||||
| Task | Metric | Result |
|
||||
|-|-|-|
|
||||
| CoLA | Matthew's corr. | 57.29 |
|
||||
| SST-2 | accuracy | 93.00 |
|
||||
| MRPC | F1/accuracy | 88.85/83.82 |
|
||||
| STS-B | Pearson/Spearman corr. | 89.70/89.37 |
|
||||
| QQP | accuracy/F1 | 90.72/87.41 |
|
||||
| MNLI | matched acc./mismatched acc.| 83.95/84.39 |
|
||||
| QNLI | accuracy | 89.04 |
|
||||
| RTE | accuracy | 61.01 |
|
||||
| WNLI | accuracy | 53.52 |
|
||||
|
||||
Some of these results are significantly different from the ones reported on the test set
|
||||
of GLUE benchmark on the website. For QQP and WNLI, please refer to [FAQ #12](https://gluebenchmark.com/faq) on the webite.
|
||||
|
||||
Before running anyone of these GLUE tasks you should download the
|
||||
[GLUE data](https://gluebenchmark.com/tasks) by running
|
||||
[this script](https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
|
||||
and unpack it to some directory `$GLUE_DIR`.
|
||||
|
||||
```shell
|
||||
export GLUE_DIR=/path/to/glue
|
||||
export TASK_NAME=MRPC
|
||||
|
||||
python run_classifier.py \
|
||||
--task_name $TASK_NAME \
|
||||
--do_train \
|
||||
--do_eval \
|
||||
--do_lower_case \
|
||||
--data_dir $GLUE_DIR/$TASK_NAME \
|
||||
--bert_model bert-base-uncased \
|
||||
--max_seq_length 128 \
|
||||
--train_batch_size 32 \
|
||||
--learning_rate 2e-5 \
|
||||
--num_train_epochs 3.0 \
|
||||
--output_dir /tmp/$TASK_NAME/
|
||||
```
|
||||
|
||||
where task name can be one of CoLA, SST-2, MRPC, STS-B, QQP, MNLI, QNLI, RTE, WNLI.
|
||||
|
||||
The dev set results will be present within the text file 'eval_results.txt' in the specified output_dir. In case of MNLI, since there are two separate dev sets, matched and mismatched, there will be a separate output folder called '/tmp/MNLI-MM/' in addition to '/tmp/MNLI/'.
|
||||
|
||||
The code has not been tested with half-precision training with apex on any GLUE task apart from MRPC, MNLI, CoLA, SST-2. The following section provides details on how to run half-precision training with MRPC. With that being said, there shouldn't be any issues in running half-precision training with the remaining GLUE tasks as well, since the data processor for each task inherits from the base class DataProcessor.
|
||||
|
||||
#### MRPC
|
||||
|
||||
This example code fine-tunes BERT on the Microsoft Research Paraphrase
|
||||
|
|
|
@ -31,6 +31,10 @@ from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
|||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from scipy.stats import pearsonr, spearmanr
|
||||
from sklearn.metrics import matthews_corrcoef, f1_score
|
||||
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
|
@ -167,6 +171,16 @@ class MnliProcessor(DataProcessor):
|
|||
return examples
|
||||
|
||||
|
||||
class MnliMismatchedProcessor(MnliProcessor):
|
||||
"""Processor for the MultiNLI Mismatched data set (GLUE version)."""
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev_mismatched.tsv")),
|
||||
"dev_matched")
|
||||
|
||||
|
||||
class ColaProcessor(DataProcessor):
|
||||
"""Processor for the CoLA data set (GLUE version)."""
|
||||
|
||||
|
@ -227,13 +241,181 @@ class Sst2Processor(DataProcessor):
|
|||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer):
|
||||
class StsbProcessor(DataProcessor):
|
||||
"""Processor for the STS-B data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return [None]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[7]
|
||||
text_b = line[8]
|
||||
label = line[-1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
class QqpProcessor(DataProcessor):
|
||||
"""Processor for the STS-B data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
try:
|
||||
text_a = line[3]
|
||||
text_b = line[4]
|
||||
label = line[5]
|
||||
except IndexError:
|
||||
continue
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
class QnliProcessor(DataProcessor):
|
||||
"""Processor for the STS-B data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")),
|
||||
"dev_matched")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["entailment", "not_entailment"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[1]
|
||||
text_b = line[2]
|
||||
label = line[-1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
class RteProcessor(DataProcessor):
|
||||
"""Processor for the RTE data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["entailment", "not_entailment"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[1]
|
||||
text_b = line[2]
|
||||
label = line[-1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
class WnliProcessor(DataProcessor):
|
||||
"""Processor for the WNLI data set (GLUE version)."""
|
||||
|
||||
def get_train_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
|
||||
|
||||
def get_dev_examples(self, data_dir):
|
||||
"""See base class."""
|
||||
return self._create_examples(
|
||||
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
|
||||
|
||||
def get_labels(self):
|
||||
"""See base class."""
|
||||
return ["0", "1"]
|
||||
|
||||
def _create_examples(self, lines, set_type):
|
||||
"""Creates examples for the training and dev sets."""
|
||||
examples = []
|
||||
for (i, line) in enumerate(lines):
|
||||
if i == 0:
|
||||
continue
|
||||
guid = "%s-%s" % (set_type, line[0])
|
||||
text_a = line[1]
|
||||
text_b = line[2]
|
||||
label = line[-1]
|
||||
examples.append(
|
||||
InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
|
||||
return examples
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||
tokenizer, output_mode):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
label_map = {label : i for i, label in enumerate(label_list)}
|
||||
|
||||
features = []
|
||||
for (ex_index, example) in enumerate(examples):
|
||||
if ex_index % 10000 == 0:
|
||||
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
|
||||
|
||||
tokens_a = tokenizer.tokenize(example.text_a)
|
||||
|
||||
tokens_b = None
|
||||
|
@ -289,7 +471,13 @@ def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer
|
|||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
label_id = label_map[example.label]
|
||||
if output_mode == "classification":
|
||||
label_id = label_map[example.label]
|
||||
elif output_mode == "regression":
|
||||
label_id = float(example.label)
|
||||
else:
|
||||
raise KeyError(output_mode)
|
||||
|
||||
if ex_index < 5:
|
||||
logger.info("*** Example ***")
|
||||
logger.info("guid: %s" % (example.guid))
|
||||
|
@ -325,9 +513,56 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
|
|||
else:
|
||||
tokens_b.pop()
|
||||
|
||||
def accuracy(out, labels):
|
||||
outputs = np.argmax(out, axis=1)
|
||||
return np.sum(outputs == labels)
|
||||
|
||||
def simple_accuracy(preds, labels):
|
||||
return (preds == labels).mean()
|
||||
|
||||
|
||||
def acc_and_f1(preds, labels):
|
||||
acc = simple_accuracy(preds, labels)
|
||||
f1 = f1_score(y_true=labels, y_pred=preds)
|
||||
return {
|
||||
"acc": acc,
|
||||
"f1": f1,
|
||||
"acc_and_f1": (acc + f1) / 2,
|
||||
}
|
||||
|
||||
|
||||
def pearson_and_spearman(preds, labels):
|
||||
pearson_corr = pearsonr(preds, labels)[0]
|
||||
spearman_corr = spearmanr(preds, labels)[0]
|
||||
return {
|
||||
"pearson": pearson_corr,
|
||||
"spearmanr": spearman_corr,
|
||||
"corr": (pearson_corr + spearman_corr) / 2,
|
||||
}
|
||||
|
||||
|
||||
def compute_metrics(task_name, preds, labels):
|
||||
assert len(preds) == len(labels)
|
||||
if task_name == "cola":
|
||||
return {"mcc": matthews_corrcoef(labels, preds)}
|
||||
elif task_name == "sst-2":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "mrpc":
|
||||
return acc_and_f1(preds, labels)
|
||||
elif task_name == "sts-b":
|
||||
return pearson_and_spearman(preds, labels)
|
||||
elif task_name == "qqp":
|
||||
return acc_and_f1(preds, labels)
|
||||
elif task_name == "mnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "mnli-mm":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "qnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "rte":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
elif task_name == "wnli":
|
||||
return {"acc": simple_accuracy(preds, labels)}
|
||||
else:
|
||||
raise KeyError(task_name)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
@ -431,15 +666,26 @@ def main():
|
|||
processors = {
|
||||
"cola": ColaProcessor,
|
||||
"mnli": MnliProcessor,
|
||||
"mnli-mm": MnliMismatchedProcessor,
|
||||
"mrpc": MrpcProcessor,
|
||||
"sst-2": Sst2Processor,
|
||||
"sts-b": StsbProcessor,
|
||||
"qqp": QqpProcessor,
|
||||
"qnli": QnliProcessor,
|
||||
"rte": RteProcessor,
|
||||
"wnli": WnliProcessor,
|
||||
}
|
||||
|
||||
num_labels_task = {
|
||||
"cola": 2,
|
||||
"sst-2": 2,
|
||||
"mnli": 3,
|
||||
"mrpc": 2,
|
||||
output_modes = {
|
||||
"cola": "classification",
|
||||
"mnli": "classification",
|
||||
"mrpc": "classification",
|
||||
"sst-2": "classification",
|
||||
"sts-b": "regression",
|
||||
"qqp": "classification",
|
||||
"qnli": "classification",
|
||||
"rte": "classification",
|
||||
"wnli": "classification",
|
||||
}
|
||||
|
||||
if args.local_rank == -1 or args.no_cuda:
|
||||
|
@ -480,8 +726,10 @@ def main():
|
|||
raise ValueError("Task not found: %s" % (task_name))
|
||||
|
||||
processor = processors[task_name]()
|
||||
num_labels = num_labels_task[task_name]
|
||||
output_mode = output_modes[task_name]
|
||||
|
||||
label_list = processor.get_labels()
|
||||
num_labels = len(label_list)
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
|
||||
|
@ -498,7 +746,7 @@ def main():
|
|||
cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model,
|
||||
cache_dir=cache_dir,
|
||||
num_labels = num_labels)
|
||||
num_labels=num_labels)
|
||||
if args.fp16:
|
||||
model.half()
|
||||
model.to(device)
|
||||
|
@ -546,7 +794,7 @@ def main():
|
|||
tr_loss = 0
|
||||
if args.do_train:
|
||||
train_features = convert_examples_to_features(
|
||||
train_examples, label_list, args.max_seq_length, tokenizer)
|
||||
train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_examples))
|
||||
logger.info(" Batch size = %d", args.train_batch_size)
|
||||
|
@ -554,7 +802,12 @@ def main():
|
|||
all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
|
||||
|
||||
if output_mode == "classification":
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)
|
||||
|
||||
train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_data)
|
||||
|
@ -569,7 +822,17 @@ def main():
|
|||
for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
|
||||
batch = tuple(t.to(device) for t in batch)
|
||||
input_ids, input_mask, segment_ids, label_ids = batch
|
||||
loss = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
|
||||
# define a new function to compute loss values for both output_modes
|
||||
logits = model(input_ids, segment_ids, input_mask, labels=None)
|
||||
|
||||
if output_mode == "classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
||||
elif output_mode == "regression":
|
||||
loss_fct = MSELoss()
|
||||
loss = loss_fct(logits.view(-1), label_ids.view(-1))
|
||||
|
||||
if n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu.
|
||||
if args.gradient_accumulation_steps > 1:
|
||||
|
@ -613,22 +876,28 @@ def main():
|
|||
if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
|
||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||
eval_features = convert_examples_to_features(
|
||||
eval_examples, label_list, args.max_seq_length, tokenizer)
|
||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Num examples = %d", len(eval_examples))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
||||
|
||||
if output_mode == "classification":
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
||||
elif output_mode == "regression":
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)
|
||||
|
||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
# Run prediction for full data
|
||||
eval_sampler = SequentialSampler(eval_data)
|
||||
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
model.eval()
|
||||
eval_loss, eval_accuracy = 0, 0
|
||||
nb_eval_steps, nb_eval_examples = 0, 0
|
||||
eval_loss = 0
|
||||
nb_eval_steps = 0
|
||||
preds = []
|
||||
|
||||
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
input_ids = input_ids.to(device)
|
||||
|
@ -637,26 +906,36 @@ def main():
|
|||
label_ids = label_ids.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
|
||||
logits = model(input_ids, segment_ids, input_mask)
|
||||
|
||||
logits = logits.detach().cpu().numpy()
|
||||
label_ids = label_ids.to('cpu').numpy()
|
||||
tmp_eval_accuracy = accuracy(logits, label_ids)
|
||||
logits = model(input_ids, segment_ids, input_mask, labels=None)
|
||||
|
||||
# create eval loss and other metric required by the task
|
||||
if output_mode == "classification":
|
||||
loss_fct = CrossEntropyLoss()
|
||||
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
||||
elif output_mode == "regression":
|
||||
loss_fct = MSELoss()
|
||||
tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
eval_accuracy += tmp_eval_accuracy
|
||||
|
||||
nb_eval_examples += input_ids.size(0)
|
||||
nb_eval_steps += 1
|
||||
if len(preds) == 0:
|
||||
preds.append(logits.detach().cpu().numpy())
|
||||
else:
|
||||
preds[0] = np.append(
|
||||
preds[0], logits.detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
eval_accuracy = eval_accuracy / nb_eval_examples
|
||||
preds = preds[0]
|
||||
if output_mode == "classification":
|
||||
preds = np.argmax(preds, axis=1)
|
||||
elif output_mode == "regression":
|
||||
preds = np.squeeze(preds)
|
||||
result = compute_metrics(task_name, preds, all_label_ids.numpy())
|
||||
loss = tr_loss/nb_tr_steps if args.do_train else None
|
||||
result = {'eval_loss': eval_loss,
|
||||
'eval_accuracy': eval_accuracy,
|
||||
'global_step': global_step,
|
||||
'loss': loss}
|
||||
|
||||
result['eval_loss'] = eval_loss
|
||||
result['global_step'] = global_step
|
||||
result['loss'] = loss
|
||||
|
||||
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
|
@ -665,5 +944,73 @@ def main():
|
|||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
# hack for MNLI-MM
|
||||
if task_name == "mnli":
|
||||
task_name = "mnli-mm"
|
||||
processor = processors[task_name]()
|
||||
|
||||
if os.path.exists(args.output_dir + '-MM') and os.listdir(args.output_dir + '-MM') and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
if not os.path.exists(args.output_dir + '-MM'):
|
||||
os.makedirs(args.output_dir + '-MM')
|
||||
|
||||
eval_examples = processor.get_dev_examples(args.data_dir)
|
||||
eval_features = convert_examples_to_features(
|
||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
|
||||
logger.info("***** Running evaluation *****")
|
||||
logger.info(" Num examples = %d", len(eval_examples))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
|
||||
all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
|
||||
|
||||
eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
|
||||
# Run prediction for full data
|
||||
eval_sampler = SequentialSampler(eval_data)
|
||||
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
model.eval()
|
||||
eval_loss = 0
|
||||
nb_eval_steps = 0
|
||||
preds = []
|
||||
|
||||
for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
input_ids = input_ids.to(device)
|
||||
input_mask = input_mask.to(device)
|
||||
segment_ids = segment_ids.to(device)
|
||||
label_ids = label_ids.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(input_ids, segment_ids, input_mask, labels=None)
|
||||
|
||||
loss_fct = CrossEntropyLoss()
|
||||
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if len(preds) == 0:
|
||||
preds.append(logits.detach().cpu().numpy())
|
||||
else:
|
||||
preds[0] = np.append(
|
||||
preds[0], logits.detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
preds = preds[0]
|
||||
preds = np.argmax(preds, axis=1)
|
||||
result = compute_metrics(task_name, preds, all_label_ids.numpy())
|
||||
loss = tr_loss/nb_tr_steps if args.do_train else None
|
||||
|
||||
result['eval_loss'] = eval_loss
|
||||
result['global_step'] = global_step
|
||||
result['loss'] = loss
|
||||
|
||||
output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt")
|
||||
with open(output_eval_file, "w") as writer:
|
||||
logger.info("***** Eval results *****")
|
||||
for key in sorted(result.keys()):
|
||||
logger.info(" %s = %s", key, str(result[key]))
|
||||
writer.write("%s = %s\n" % (key, str(result[key])))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue