diff --git a/run_squad.py b/run_squad.py index e90683ef73..fa7575de25 100644 --- a/run_squad.py +++ b/run_squad.py @@ -406,7 +406,7 @@ RawResult = collections.namedtuple("RawResult", def write_predictions(all_examples, all_features, all_results, n_best_size, max_answer_length, do_lower_case, output_prediction_file, - output_nbest_file): + output_nbest_file, verbose_logging): """Write final predictions to the json file.""" logger.info("Writing predictions to: %s" % (output_prediction_file)) logger.info("Writing nbest to: %s" % (output_nbest_file)) @@ -492,7 +492,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, tok_text = " ".join(tok_text.split()) orig_text = " ".join(orig_tokens) - final_text = get_final_text(tok_text, orig_text, do_lower_case) + final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging) if final_text in seen_predictions: continue @@ -538,7 +538,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, writer.write(json.dumps(all_nbest_json, indent=4) + "\n") -def get_final_text(pred_text, orig_text, do_lower_case): +def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): """Project the tokenized prediction back to the original text.""" # When we created the data, we kept track of the alignment between original @@ -587,7 +587,7 @@ def get_final_text(pred_text, orig_text, do_lower_case): start_position = tok_text.find(pred_text) if start_position == -1: - if args.verbose_logging: + if verbose_logging: logger.info( "Unable to find text: '%s' in '%s'" % (pred_text, orig_text)) return orig_text @@ -597,7 +597,7 @@ def get_final_text(pred_text, orig_text, do_lower_case): (tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text) if len(orig_ns_text) != len(tok_ns_text): - if args.verbose_logging: + if verbose_logging: logger.info("Length not equal after stripping spaces: '%s' vs '%s'", orig_ns_text, tok_ns_text) return orig_text @@ -615,7 +615,7 @@ def get_final_text(pred_text, orig_text, do_lower_case): orig_start_position = orig_ns_to_s_map[ns_start_position] if orig_start_position is None: - if args.verbose_logging: + if verbose_logging: logger.info("Couldn't map start position") return orig_text @@ -626,7 +626,7 @@ def get_final_text(pred_text, orig_text, do_lower_case): orig_end_position = orig_ns_to_s_map[ns_end_position] if orig_end_position is None: - if args.verbose_logging: + if verbose_logging: logger.info("Couldn't map end position") return orig_text @@ -949,7 +949,7 @@ def main(): write_predictions(eval_examples, eval_features, all_results, args.n_best_size, args.max_answer_length, args.do_lower_case, output_prediction_file, - output_nbest_file) + output_nbest_file, args.verbose_logging) if __name__ == "__main__":