From 67403413bd8f8e00759a9cffe8608e092fa7b519 Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Tue, 20 Sep 2022 06:17:57 -0700 Subject: [PATCH] Change document question answering pipeline to always return an array (#19071) Co-authored-by: Ankur Goyal --- src/transformers/pipelines/document_question_answering.py | 2 -- tests/pipelines/test_pipelines_document_question_answering.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/pipelines/document_question_answering.py b/src/transformers/pipelines/document_question_answering.py index b0fe18cb9d..a6afc069fb 100644 --- a/src/transformers/pipelines/document_question_answering.py +++ b/src/transformers/pipelines/document_question_answering.py @@ -383,8 +383,6 @@ class DocumentQuestionAnsweringPipeline(Pipeline): answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs) answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k] - if len(answers) == 1: - return answers[0] return answers def postprocess_donut(self, model_outputs, **kwargs): diff --git a/tests/pipelines/test_pipelines_document_question_answering.py b/tests/pipelines/test_pipelines_document_question_answering.py index 091f6c3c03..92d618bfd6 100644 --- a/tests/pipelines/test_pipelines_document_question_answering.py +++ b/tests/pipelines/test_pipelines_document_question_answering.py @@ -267,7 +267,7 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli image = INVOICE_URL question = "What is the invoice number?" outputs = dqa_pipeline(image=image, question=question, top_k=2) - self.assertEqual(nested_simplify(outputs, decimals=4), {"answer": "us-001"}) + self.assertEqual(nested_simplify(outputs, decimals=4), [{"answer": "us-001"}]) @require_tf @unittest.skip("Document question answering not implemented in TF")