From 40dbda6871263067e3cf2030a1e9aaffef7837e5 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Tue, 18 Jun 2019 16:45:52 +0200 Subject: [PATCH] updating classification example --- examples/run_classifier.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/examples/run_classifier.py b/examples/run_classifier.py index 49fb3954b3..0add05113f 100644 --- a/examples/run_classifier.py +++ b/examples/run_classifier.py @@ -228,10 +228,10 @@ def main(): # Prepare data loader train_examples = processor.get_train_examples(args.data_dir) - cached_train_features_file = args.data_dir + '_{0}_{1}_{2}'.format( + cached_train_features_file = os.path.join(args.data_dir, 'train_{0}_{1}_{2}'.format( list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), - str(task_name)) + str(task_name))) try: with open(cached_train_features_file, "rb") as reader: train_features = pickle.load(reader) @@ -311,7 +311,7 @@ def main(): input_ids, input_mask, segment_ids, label_ids = batch # define a new function to compute loss values for both output_modes - logits = model(input_ids, segment_ids, input_mask, labels=None) + logits = model(input_ids, segment_ids, input_mask) if output_mode == "classification": loss_fct = CrossEntropyLoss() @@ -380,6 +380,22 @@ def main(): ### Evaluation if args.do_eval: eval_examples = processor.get_dev_examples(args.data_dir) + cached_train_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format( + list(filter(None, args.bert_model.split('/'))).pop(), + str(args.max_seq_length), + str(task_name))) + try: + with open(cached_train_features_file, "rb") as reader: + train_features = pickle.load(reader) + except: + train_features = convert_examples_to_features( + train_examples, label_list, args.max_seq_length, tokenizer, output_mode) + if args.local_rank == -1 or torch.distributed.get_rank() == 0: + logger.info(" Saving train features into cached file %s", cached_train_features_file) + with open(cached_train_features_file, "wb") as writer: + pickle.dump(train_features, writer) + + eval_features = convert_examples_to_features( eval_examples, label_list, args.max_seq_length, tokenizer, output_mode) logger.info("***** Running evaluation *****") @@ -414,7 +430,7 @@ def main(): label_ids = label_ids.to(device) with torch.no_grad(): - logits = model(input_ids, segment_ids, input_mask, labels=None) + logits = model(input_ids, segment_ids, input_mask) # create eval loss and other metric required by the task if output_mode == "classification":