fixing verbose_argument

This commit is contained in:
thomwolf 2018-11-04 09:53:29 +01:00
parent d6418c5ef3
commit 26bdef4321
1 changed files with 8 additions and 8 deletions

View File

@ -406,7 +406,7 @@ RawResult = collections.namedtuple("RawResult",
def write_predictions(all_examples, all_features, all_results, n_best_size,
max_answer_length, do_lower_case, output_prediction_file,
output_nbest_file):
output_nbest_file, verbose_logging):
"""Write final predictions to the json file."""
logger.info("Writing predictions to: %s" % (output_prediction_file))
logger.info("Writing nbest to: %s" % (output_nbest_file))
@ -492,7 +492,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
tok_text = " ".join(tok_text.split())
orig_text = " ".join(orig_tokens)
final_text = get_final_text(tok_text, orig_text, do_lower_case)
final_text = get_final_text(tok_text, orig_text, do_lower_case, verbose_logging)
if final_text in seen_predictions:
continue
@ -538,7 +538,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
def get_final_text(pred_text, orig_text, do_lower_case):
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
"""Project the tokenized prediction back to the original text."""
# When we created the data, we kept track of the alignment between original
@ -587,7 +587,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
start_position = tok_text.find(pred_text)
if start_position == -1:
if args.verbose_logging:
if verbose_logging:
logger.info(
"Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
return orig_text
@ -597,7 +597,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
(tok_ns_text, tok_ns_to_s_map) = _strip_spaces(tok_text)
if len(orig_ns_text) != len(tok_ns_text):
if args.verbose_logging:
if verbose_logging:
logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
orig_ns_text, tok_ns_text)
return orig_text
@ -615,7 +615,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
orig_start_position = orig_ns_to_s_map[ns_start_position]
if orig_start_position is None:
if args.verbose_logging:
if verbose_logging:
logger.info("Couldn't map start position")
return orig_text
@ -626,7 +626,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
orig_end_position = orig_ns_to_s_map[ns_end_position]
if orig_end_position is None:
if args.verbose_logging:
if verbose_logging:
logger.info("Couldn't map end position")
return orig_text
@ -949,7 +949,7 @@ def main():
write_predictions(eval_examples, eval_features, all_results,
args.n_best_size, args.max_answer_length,
args.do_lower_case, output_prediction_file,
output_nbest_file)
output_nbest_file, args.verbose_logging)
if __name__ == "__main__":