From 2ef774211733f0acf8d3415f9284c49ef219e991 Mon Sep 17 00:00:00 2001 From: Ankur Goyal Date: Wed, 7 Sep 2022 10:38:49 -0700 Subject: [PATCH] Add DocumentQuestionAnswering pipeline (#18414) * [WIP] Skeleton of VisualQuestionAnweringPipeline extended to support LayoutLM-like models * Fixup * Use the full encoding * Basic refactoring to DocumentQuestionAnsweringPipeline * Cleanup * Improve args, docs, and implement preprocessing * Integrate OCR * Refactor question_answering pipeline * Use refactored QA code in the document qa pipeline * Fix tests * Some small cleanups * Use a string type annotation for Image.Image * Update encoding with image features * Wire through the basic docs * Handle invalid response * Handle empty word_boxes properly * Docstring fix * Integrate Donut model * Fixup * Incorporate comments * Address comments * Initial incorporation of tests * Address Comments * Change assert to ValueError * Comments * Wrap `score` in float to make it JSON serializable * Incorporate AutoModeLForDocumentQuestionAnswering changes * Fixup * Rename postprocess function * Fix auto import * Applying comments * Improve docs * Remove extra assets and add copyright * Address comments Co-authored-by: Ankur Goyal --- docs/source/en/main_classes/pipelines.mdx | 7 + docs/source/en/model_doc/auto.mdx | 8 + src/transformers/__init__.py | 10 + src/transformers/models/auto/__init__.py | 8 + src/transformers/models/auto/modeling_auto.py | 22 + .../models/auto/modeling_tf_auto.py | 21 + src/transformers/pipelines/__init__.py | 13 +- src/transformers/pipelines/base.py | 4 +- .../pipelines/document_question_answering.py | 443 ++++++++++++++++++ .../pipelines/question_answering.py | 188 +++++--- src/transformers/utils/dummy_pt_objects.py | 10 + src/transformers/utils/dummy_tf_objects.py | 10 + src/transformers/utils/fx.py | 5 +- .../models/layoutlm/test_modeling_layoutlm.py | 30 -- .../layoutlm/test_modeling_tf_layoutlm.py | 25 - ...t_pipelines_document_question_answering.py | 280 +++++++++++ tests/test_modeling_common.py | 11 +- tests/test_modeling_tf_common.py | 6 +- 18 files changed, 962 insertions(+), 139 deletions(-) create mode 100644 src/transformers/pipelines/document_question_answering.py create mode 100644 tests/pipelines/test_pipelines_document_question_answering.py diff --git a/docs/source/en/main_classes/pipelines.mdx b/docs/source/en/main_classes/pipelines.mdx index b2de7e048d..4043a00009 100644 --- a/docs/source/en/main_classes/pipelines.mdx +++ b/docs/source/en/main_classes/pipelines.mdx @@ -25,6 +25,7 @@ There are two categories of pipeline abstractions to be aware about: - [`AudioClassificationPipeline`] - [`AutomaticSpeechRecognitionPipeline`] - [`ConversationalPipeline`] + - [`DocumentQuestionAnsweringPipeline`] - [`FeatureExtractionPipeline`] - [`FillMaskPipeline`] - [`ImageClassificationPipeline`] @@ -342,6 +343,12 @@ That should enable you to do all the custom code you want. 
- __call__ - all +### DocumentQuestionAnsweringPipeline + +[[autodoc]] DocumentQuestionAnsweringPipeline + - __call__ + - all + ### FeatureExtractionPipeline [[autodoc]] FeatureExtractionPipeline diff --git a/docs/source/en/model_doc/auto.mdx b/docs/source/en/model_doc/auto.mdx index 995296485b..93976424ba 100644 --- a/docs/source/en/model_doc/auto.mdx +++ b/docs/source/en/model_doc/auto.mdx @@ -114,6 +114,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its [[autodoc]] AutoModelForTableQuestionAnswering +## AutoModelForDocumentQuestionAnswering + +[[autodoc]] AutoModelForDocumentQuestionAnswering + ## AutoModelForImageClassification [[autodoc]] AutoModelForImageClassification @@ -214,6 +218,10 @@ Likewise, if your `NewModel` is a subclass of [`PreTrainedModel`], make sure its [[autodoc]] TFAutoModelForTableQuestionAnswering +## TFAutoModelForDocumentQuestionAnswering + +[[autodoc]] TFAutoModelForDocumentQuestionAnswering + ## TFAutoModelForTokenClassification [[autodoc]] TFAutoModelForTokenClassification diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 4651c3b5b9..e10e2ce0ba 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -383,6 +383,7 @@ _import_structure = { "Conversation", "ConversationalPipeline", "CsvPipelineDataFormat", + "DocumentQuestionAnsweringPipeline", "FeatureExtractionPipeline", "FillMaskPipeline", "ImageClassificationPipeline", @@ -789,6 +790,7 @@ else: "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", "MODEL_FOR_CAUSAL_LM_MAPPING", "MODEL_FOR_CTC_MAPPING", + "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", @@ -816,6 +818,7 @@ else: "AutoModelForAudioXVector", "AutoModelForCausalLM", "AutoModelForCTC", + "AutoModelForDocumentQuestionAnswering", "AutoModelForImageClassification", "AutoModelForImageSegmentation", "AutoModelForInstanceSegmentation", @@ -2107,6 +2110,7 @@ else: "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "TF_MODEL_FOR_PRETRAINING_MAPPING", + "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING", "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", @@ -2124,6 +2128,7 @@ else: "TFAutoModelForMultipleChoice", "TFAutoModelForNextSentencePrediction", "TFAutoModelForPreTraining", + "TFAutoModelForDocumentQuestionAnswering", "TFAutoModelForQuestionAnswering", "TFAutoModelForSemanticSegmentation", "TFAutoModelForSeq2SeqLM", @@ -3200,6 +3205,7 @@ if TYPE_CHECKING: Conversation, ConversationalPipeline, CsvPipelineDataFormat, + DocumentQuestionAnsweringPipeline, FeatureExtractionPipeline, FillMaskPipeline, ImageClassificationPipeline, @@ -3549,6 +3555,7 @@ if TYPE_CHECKING: MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_CTC_MAPPING, + MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, @@ -3576,6 +3583,7 @@ if TYPE_CHECKING: AutoModelForAudioXVector, AutoModelForCausalLM, AutoModelForCTC, + AutoModelForDocumentQuestionAnswering, AutoModelForImageClassification, AutoModelForImageSegmentation, AutoModelForInstanceSegmentation, @@ -4637,6 +4645,7 @@ if TYPE_CHECKING: ) from .models.auto import ( TF_MODEL_FOR_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, 
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, @@ -4655,6 +4664,7 @@ if TYPE_CHECKING: TF_MODEL_WITH_LM_HEAD_MAPPING, TFAutoModel, TFAutoModelForCausalLM, + TFAutoModelForDocumentQuestionAnswering, TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, diff --git a/src/transformers/models/auto/__init__.py b/src/transformers/models/auto/__init__.py index ec253f6037..6129253f14 100644 --- a/src/transformers/models/auto/__init__.py +++ b/src/transformers/models/auto/__init__.py @@ -47,6 +47,7 @@ else: "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", "MODEL_FOR_CAUSAL_LM_MAPPING", "MODEL_FOR_CTC_MAPPING", + "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", @@ -93,6 +94,7 @@ else: "AutoModelForVideoClassification", "AutoModelForVision2Seq", "AutoModelForVisualQuestionAnswering", + "AutoModelForDocumentQuestionAnswering", "AutoModelWithLMHead", ] @@ -111,6 +113,7 @@ else: "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", "TF_MODEL_FOR_PRETRAINING_MAPPING", "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", "TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", @@ -127,6 +130,7 @@ else: "TFAutoModelForMultipleChoice", "TFAutoModelForNextSentencePrediction", "TFAutoModelForPreTraining", + "TFAutoModelForDocumentQuestionAnswering", "TFAutoModelForQuestionAnswering", "TFAutoModelForSemanticSegmentation", "TFAutoModelForSeq2SeqLM", @@ -191,6 +195,7 @@ if TYPE_CHECKING: MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, MODEL_FOR_CTC_MAPPING, + MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, @@ -218,6 +223,7 @@ if TYPE_CHECKING: AutoModelForAudioXVector, AutoModelForCausalLM, AutoModelForCTC, + AutoModelForDocumentQuestionAnswering, AutoModelForImageClassification, AutoModelForImageSegmentation, AutoModelForInstanceSegmentation, @@ -248,6 +254,7 @@ if TYPE_CHECKING: else: from .modeling_tf_auto import ( TF_MODEL_FOR_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, @@ -266,6 +273,7 @@ if TYPE_CHECKING: TF_MODEL_WITH_LM_HEAD_MAPPING, TFAutoModel, TFAutoModelForCausalLM, + TFAutoModelForDocumentQuestionAnswering, TFAutoModelForImageClassification, TFAutoModelForMaskedLM, TFAutoModelForMultipleChoice, diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 5060b535b0..1cb0ae44db 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -603,6 +603,14 @@ MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ] ) +MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + ("layoutlm", "LayoutLMForQuestionAnswering"), + ("layoutlmv2", "LayoutLMv2ForQuestionAnswering"), + ("layoutlmv3", "LayoutLMv3ForQuestionAnswering"), + ] +) + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict( [ # Model for Token Classification mapping @@ -773,6 +781,9 @@ MODEL_FOR_VISION_2_SEQ_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FO 
MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING_NAMES ) +MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES +) MODEL_FOR_MASKED_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES) MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES @@ -891,6 +902,17 @@ AutoModelForVisualQuestionAnswering = auto_class_update( ) +class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING + + +AutoModelForDocumentQuestionAnswering = auto_class_update( + AutoModelForDocumentQuestionAnswering, + head_doc="document question answering", + checkpoint_for_example='impira/layoutlm-document-qa", revision="3dc6de3', +) + + class AutoModelForTokenClassification(_BaseAutoModelClass): _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index a12f6accdc..ba1e74e14c 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -315,6 +315,13 @@ TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( ] ) +TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( + [ + ("layoutlm", "TFLayoutLMForQuestionAnswering"), + ] +) + + TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict( [ # Model for Table Question Answering mapping @@ -406,6 +413,9 @@ TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING = _LazyAutoMapping( TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES ) +TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( + CONFIG_MAPPING_NAMES, TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES +) TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING = _LazyAutoMapping( CONFIG_MAPPING_NAMES, TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES ) @@ -515,6 +525,17 @@ class TFAutoModelForQuestionAnswering(_BaseAutoModelClass): TFAutoModelForQuestionAnswering = auto_class_update(TFAutoModelForQuestionAnswering, head_doc="question answering") +class TFAutoModelForDocumentQuestionAnswering(_BaseAutoModelClass): + _model_mapping = TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING + + +TFAutoModelForDocumentQuestionAnswering = auto_class_update( + TFAutoModelForDocumentQuestionAnswering, + head_doc="document question answering", + checkpoint_for_example='impira/layoutlm-document-qa", revision="3dc6de3', +) + + class TFAutoModelForTableQuestionAnswering(_BaseAutoModelClass): _model_mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING diff --git a/src/transformers/pipelines/__init__.py b/src/transformers/pipelines/__init__.py index ee7dee57c0..e3f9e603b5 100755 --- a/src/transformers/pipelines/__init__.py +++ b/src/transformers/pipelines/__init__.py @@ -51,6 +51,7 @@ from .base import ( infer_framework_load_model, ) from .conversational import Conversation, ConversationalPipeline +from .document_question_answering import DocumentQuestionAnsweringPipeline from .feature_extraction import FeatureExtractionPipeline from .fill_mask import FillMaskPipeline from .image_classification import ImageClassificationPipeline @@ -109,6 +110,7 @@ if is_torch_available(): AutoModelForAudioClassification, AutoModelForCausalLM, 
AutoModelForCTC, + AutoModelForDocumentQuestionAnswering, AutoModelForImageClassification, AutoModelForImageSegmentation, AutoModelForMaskedLM, @@ -215,6 +217,15 @@ SUPPORTED_TASKS = { }, "type": "multimodal", }, + "document-question-answering": { + "impl": DocumentQuestionAnsweringPipeline, + "pt": (AutoModelForDocumentQuestionAnswering,) if is_torch_available() else (), + "tf": (), + "default": { + "model": {"pt": ("impira/layoutlm-document-qa", "3a93017")}, + }, + "type": "multimodal", + }, "fill-mask": { "impl": FillMaskPipeline, "tf": (TFAutoModelForMaskedLM,) if is_tf_available() else (), @@ -443,7 +454,7 @@ def pipeline( trust_remote_code: Optional[bool] = None, model_kwargs: Dict[str, Any] = None, pipeline_class: Optional[Any] = None, - **kwargs + **kwargs, ) -> Pipeline: """ Utility factory method to build a [`Pipeline`]. diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 7842b95b32..b5e7c9cb58 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -178,7 +178,7 @@ def infer_framework_load_model( model_classes: Optional[Dict[str, Tuple[type]]] = None, task: Optional[str] = None, framework: Optional[str] = None, - **model_kwargs + **model_kwargs, ): """ Select framework (TensorFlow or PyTorch) to use from the `model` passed. Returns a tuple (framework, model). @@ -274,7 +274,7 @@ def infer_framework_from_model( model_classes: Optional[Dict[str, Tuple[type]]] = None, task: Optional[str] = None, framework: Optional[str] = None, - **model_kwargs + **model_kwargs, ): """ Select framework (TensorFlow or PyTorch) to use from the `model` passed. Returns a tuple (framework, model). diff --git a/src/transformers/pipelines/document_question_answering.py b/src/transformers/pipelines/document_question_answering.py new file mode 100644 index 0000000000..b0fe18cb9d --- /dev/null +++ b/src/transformers/pipelines/document_question_answering.py @@ -0,0 +1,443 @@ +# Copyright 2022 The Impira Team and the HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import List, Optional, Tuple, Union + +import numpy as np + +from ..utils import ( + ExplicitEnum, + add_end_docstrings, + is_pytesseract_available, + is_torch_available, + is_vision_available, + logging, +) +from .base import PIPELINE_INIT_ARGS, Pipeline +from .question_answering import select_starts_ends + + +if is_vision_available(): + from PIL import Image + + from ..image_utils import load_image + +if is_torch_available(): + import torch + + from ..models.auto.modeling_auto import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING + +TESSERACT_LOADED = False +if is_pytesseract_available(): + TESSERACT_LOADED = True + import pytesseract + +logger = logging.get_logger(__name__) + + +# normalize_bbox() and apply_tesseract() are derived from apply_tesseract in models/layoutlmv3/feature_extraction_layoutlmv3.py. +# However, because the pipeline may evolve from what layoutlmv3 currently does, it's copied (vs. 
imported) to avoid creating an +# unecessary dependency. +def normalize_box(box, width, height): + return [ + int(1000 * (box[0] / width)), + int(1000 * (box[1] / height)), + int(1000 * (box[2] / width)), + int(1000 * (box[3] / height)), + ] + + +def apply_tesseract(image: "Image.Image", lang: Optional[str], tesseract_config: Optional[str]): + """Applies Tesseract OCR on a document image, and returns recognized words + normalized bounding boxes.""" + # apply OCR + data = pytesseract.image_to_data(image, lang=lang, output_type="dict", config=tesseract_config) + words, left, top, width, height = data["text"], data["left"], data["top"], data["width"], data["height"] + + # filter empty words and corresponding coordinates + irrelevant_indices = [idx for idx, word in enumerate(words) if not word.strip()] + words = [word for idx, word in enumerate(words) if idx not in irrelevant_indices] + left = [coord for idx, coord in enumerate(left) if idx not in irrelevant_indices] + top = [coord for idx, coord in enumerate(top) if idx not in irrelevant_indices] + width = [coord for idx, coord in enumerate(width) if idx not in irrelevant_indices] + height = [coord for idx, coord in enumerate(height) if idx not in irrelevant_indices] + + # turn coordinates into (left, top, left+width, top+height) format + actual_boxes = [] + for x, y, w, h in zip(left, top, width, height): + actual_box = [x, y, x + w, y + h] + actual_boxes.append(actual_box) + + image_width, image_height = image.size + + # finally, normalize the bounding boxes + normalized_boxes = [] + for box in actual_boxes: + normalized_boxes.append(normalize_box(box, image_width, image_height)) + + if len(words) != len(normalized_boxes): + raise ValueError("Not as many words as there are bounding boxes") + + return words, normalized_boxes + + +class ModelType(ExplicitEnum): + LayoutLM = "layoutlm" + LayoutLMv2andv3 = "layoutlmv2andv3" + VisionEncoderDecoder = "vision_encoder_decoder" + + +@add_end_docstrings(PIPELINE_INIT_ARGS) +class DocumentQuestionAnsweringPipeline(Pipeline): + # TODO: Update task_summary docs to include an example with document QA and then update the first sentence + """ + Document Question Answering pipeline using any `AutoModelForDocumentQuestionAnswering`. The inputs/outputs are + similar to the (extractive) question answering pipeline; however, the pipeline takes an image (and optional OCR'd + words/boxes) as input instead of text context. + + This document question answering pipeline can currently be loaded from [`pipeline`] using the following task + identifier: `"document-question-answering"`. + + The models that this pipeline can use are models that have been fine-tuned on a document question answering task. + See the up-to-date list of available models on + [huggingface.co/models](https://huggingface.co/models?filter=document-question-answering). 
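+
+    A minimal usage sketch (illustrative; the checkpoint shown is the task's default and the image is the pinned
+    invoice used in the tests; actual scores and answers depend on the model and on the OCR results):
+
+    ```python
+    from transformers import pipeline
+
+    document_qa = pipeline("document-question-answering", model="impira/layoutlm-document-qa")
+    document_qa(
+        image="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/invoice.png",
+        question="What is the invoice number?",
+    )
+    # e.g. {"score": ..., "answer": ..., "start": ..., "end": ...}
+    ```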
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.check_model_type(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING) + + if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig": + self.model_type = ModelType.VisionEncoderDecoder + if self.model.config.encoder.model_type != "donut-swin": + raise ValueError("Currently, the only supported VisionEncoderDecoder model is Donut") + elif self.model.config.__class__.__name__ == "LayoutLMConfig": + self.model_type = ModelType.LayoutLM + else: + self.model_type = ModelType.LayoutLMv2andv3 + + def _sanitize_parameters( + self, + padding=None, + doc_stride=None, + max_question_len=None, + lang: Optional[str] = None, + tesseract_config: Optional[str] = None, + max_answer_len=None, + max_seq_len=None, + top_k=None, + handle_impossible_answer=None, + **kwargs, + ): + preprocess_params, postprocess_params = {}, {} + if padding is not None: + preprocess_params["padding"] = padding + if doc_stride is not None: + preprocess_params["doc_stride"] = doc_stride + if max_question_len is not None: + preprocess_params["max_question_len"] = max_question_len + if max_seq_len is not None: + preprocess_params["max_seq_len"] = max_seq_len + if lang is not None: + preprocess_params["lang"] = lang + if tesseract_config is not None: + preprocess_params["tesseract_config"] = tesseract_config + + if top_k is not None: + if top_k < 1: + raise ValueError(f"top_k parameter should be >= 1 (got {top_k})") + postprocess_params["top_k"] = top_k + if max_answer_len is not None: + if max_answer_len < 1: + raise ValueError(f"max_answer_len parameter should be >= 1 (got {max_answer_len}") + postprocess_params["max_answer_len"] = max_answer_len + if handle_impossible_answer is not None: + postprocess_params["handle_impossible_answer"] = handle_impossible_answer + + return preprocess_params, {}, postprocess_params + + def __call__( + self, + image: Union["Image.Image", str], + question: Optional[str] = None, + word_boxes: Tuple[str, List[float]] = None, + **kwargs, + ): + """ + Answer the question(s) given as inputs by using the document(s). A document is defined as an image and an + optional list of (word, box) tuples which represent the text in the document. If the `word_boxes` are not + provided, it will use the Tesseract OCR engine (if available) to extract the words and boxes automatically for + LayoutLM-like models which require them as input. For Donut, no OCR is run. + + You can invoke the pipeline several ways: + + - `pipeline(image=image, question=question)` + - `pipeline(image=image, question=question, word_boxes=word_boxes)` + - `pipeline([{"image": image, "question": question}])` + - `pipeline([{"image": image, "question": question, "word_boxes": word_boxes}])` + + Args: + image (`str` or `PIL.Image`): + The pipeline handles three types of images: + + - A string containing a http link pointing to an image + - A string containing a local path to an image + - An image loaded in PIL directly + + The pipeline accepts either a single image or a batch of images. If given a single image, it can be + broadcasted to multiple questions. + question (`str`): + A question to ask of the document. + word_boxes (`List[str, Tuple[float, float, float, float]]`, *optional*): + A list of words and bounding boxes (normalized 0->1000). If you provide this optional input, then the + pipeline will use these words and boxes instead of running OCR on the image to derive them for models + that need them (e.g. LayoutLM). 
This allows you to reuse OCR'd results across many invocations of the + pipeline without having to re-run it each time. + top_k (`int`, *optional*, defaults to 1): + The number of answers to return (will be chosen by order of likelihood). Note that we return less than + top_k answers if there are not enough options available within the context. + doc_stride (`int`, *optional*, defaults to 128): + If the words in the document are too long to fit with the question for the model, it will be split in + several chunks with some overlap. This argument controls the size of that overlap. + max_answer_len (`int`, *optional*, defaults to 15): + The maximum length of predicted answers (e.g., only answers with a shorter length are considered). + max_seq_len (`int`, *optional*, defaults to 384): + The maximum length of the total sentence (context + question) in tokens of each chunk passed to the + model. The context will be split in several chunks (using `doc_stride` as overlap) if needed. + max_question_len (`int`, *optional*, defaults to 64): + The maximum length of the question after tokenization. It will be truncated if needed. + handle_impossible_answer (`bool`, *optional*, defaults to `False`): + Whether or not we accept impossible as an answer. + lang (`str`, *optional*): + Language to use while running OCR. Defaults to english. + tesseract_config (`str`, *optional*): + Additional flags to pass to tesseract while running OCR. + + Return: + A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys: + + - **score** (`float`) -- The probability associated to the answer. + - **start** (`int`) -- The start word index of the answer (in the OCR'd version of the input or provided + `word_boxes`). + - **end** (`int`) -- The end word index of the answer (in the OCR'd version of the input or provided + `word_boxes`). + - **answer** (`str`) -- The answer to the question. + """ + if isinstance(question, str): + inputs = {"question": question, "image": image} + if word_boxes is not None: + inputs["word_boxes"] = word_boxes + else: + inputs = image + return super().__call__(inputs, **kwargs) + + def preprocess(self, input, lang=None, tesseract_config=""): + image = None + image_features = {} + if input.get("image", None) is not None: + image = load_image(input["image"]) + if self.feature_extractor is not None: + image_features.update(self.feature_extractor(images=image, return_tensors=self.framework)) + elif self.model_type == ModelType.VisionEncoderDecoder: + raise ValueError("If you are using a VisionEncoderDecoderModel, you must provide a feature extractor") + + words, boxes = None, None + if not self.model_type == ModelType.VisionEncoderDecoder: + if "word_boxes" in input: + words = [x[0] for x in input["word_boxes"]] + boxes = [x[1] for x in input["word_boxes"]] + elif "words" in image_features and "boxes" in image_features: + words = image_features.pop("words")[0] + boxes = image_features.pop("boxes")[0] + elif image is not None: + if not TESSERACT_LOADED: + raise ValueError( + "If you provide an image without word_boxes, then the pipeline will run OCR using Tesseract," + " but pytesseract is not available" + ) + if TESSERACT_LOADED: + words, boxes = apply_tesseract(image, lang=lang, tesseract_config=tesseract_config) + else: + raise ValueError( + "You must provide an image or word_boxes. 
If you provide an image, the pipeline will automatically" " run OCR to derive words and boxes" ) + + if self.tokenizer.padding_side != "right": + raise ValueError( + "Document question answering only supports tokenizers whose padding side is 'right', not" + f" {self.tokenizer.padding_side}" + ) + + if self.model_type == ModelType.VisionEncoderDecoder: + task_prompt = f'<s_docvqa><s_question>{input["question"]}</s_question><s_answer>' + # Adapted from https://huggingface.co/spaces/nielsr/donut-docvqa/blob/main/app.py + encoding = { + "inputs": image_features["pixel_values"], + "decoder_input_ids": self.tokenizer( + task_prompt, add_special_tokens=False, return_tensors=self.framework + ).input_ids, + "return_dict_in_generate": True, + } + p_mask = None + word_ids = None + words = None + else: + tokenizer_kwargs = {} + if self.model_type == ModelType.LayoutLM: + tokenizer_kwargs["text"] = input["question"].split() + tokenizer_kwargs["text_pair"] = words + tokenizer_kwargs["is_split_into_words"] = True + else: + tokenizer_kwargs["text"] = [input["question"]] + tokenizer_kwargs["text_pair"] = [words] + tokenizer_kwargs["boxes"] = [boxes] + + encoding = self.tokenizer( + return_token_type_ids=True, + return_tensors=self.framework, + # TODO: In a future PR, use these features to handle sequences whose length is longer than + # the maximum allowed by the model. Currently, the tokenizer will produce a sequence that + # may be too long for the model to handle. + # truncation="only_second", + # return_overflowing_tokens=True, + **tokenizer_kwargs, + ) + + if "pixel_values" in image_features: + encoding["image"] = image_features.pop("pixel_values") + + # TODO: For now, this should always be num_spans == 1 given the flags we've passed in above, but the + # code is written to naturally handle multiple spans at the right time. + num_spans = len(encoding["input_ids"]) + + # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer) + # We put 0 on the tokens from the context and 1 everywhere else (question and special tokens) + # This logic mirrors the logic in the question_answering pipeline + p_mask = [[tok != 1 for tok in encoding.sequence_ids(span_id)] for span_id in range(num_spans)] + for span_idx in range(num_spans): + input_ids_span_idx = encoding["input_ids"][span_idx] + # keep the cls_token unmasked (some models use it to indicate unanswerable questions) + if self.tokenizer.cls_token_id is not None: + cls_indices = np.nonzero(np.array(input_ids_span_idx) == self.tokenizer.cls_token_id)[0] + for cls_index in cls_indices: + p_mask[span_idx][cls_index] = 0 + + # For each span, place a bounding box [0,0,0,0] for question and CLS tokens, [1000,1000,1000,1000] + # for SEP tokens, and the word's bounding box for words in the original document.
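+ # For example (an illustrative sketch; the question, words and tokenization are made up): for the question
+ # "What is the total?" and document words ["total", "$10"], the encoded sequence is roughly
+ # [CLS] what is the total ? [SEP] total $10 [SEP]
+ # so bbox receives [0, 0, 0, 0] for the CLS and question tokens, [1000, 1000, 1000, 1000] for each SEP token,
+ # and boxes[word_id] for every token that belongs to a document word.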
+ if "boxes" not in tokenizer_kwargs: + bbox = [] + for batch_index in range(num_spans): + for input_id, sequence_id, word_id in zip( + encoding.input_ids[batch_index], + encoding.sequence_ids(batch_index), + encoding.word_ids(batch_index), + ): + if sequence_id == 1: + bbox.append(boxes[word_id]) + elif input_id == self.tokenizer.sep_token_id: + bbox.append([1000] * 4) + else: + bbox.append([0] * 4) + + if self.framework == "tf": + raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline") + elif self.framework == "pt": + encoding["bbox"] = torch.tensor([bbox]) + + word_ids = [encoding.word_ids(i) for i in range(num_spans)] + + return {**encoding, "p_mask": p_mask, "word_ids": word_ids, "words": words} + + def _forward(self, model_inputs): + p_mask = model_inputs.pop("p_mask", None) + word_ids = model_inputs.pop("word_ids", None) + words = model_inputs.pop("words", None) + + if self.model_type == ModelType.VisionEncoderDecoder: + model_outputs = self.model.generate(**model_inputs) + else: + model_outputs = self.model(**model_inputs) + + model_outputs["p_mask"] = p_mask + model_outputs["word_ids"] = word_ids + model_outputs["words"] = words + model_outputs["attention_mask"] = model_inputs.get("attention_mask", None) + return model_outputs + + def postprocess(self, model_outputs, top_k=1, **kwargs): + if self.model_type == ModelType.VisionEncoderDecoder: + answers = self.postprocess_donut(model_outputs) + else: + answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs) + + answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k] + if len(answers) == 1: + return answers[0] + return answers + + def postprocess_donut(self, model_outputs, **kwargs): + sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0] + + # TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer + # (see https://github.com/huggingface/transformers/pull/18414/files#r961747408 for more context). + sequence = sequence.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "") + sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token + ret = { + "answer": None, + } + + answer = re.search(r"(.*)", sequence) + if answer is not None: + ret["answer"] = answer.group(1).strip() + return [ret] + + def postprocess_extractive_qa( + self, model_outputs, top_k=1, handle_impossible_answer=False, max_answer_len=15, **kwargs + ): + min_null_score = 1000000 # large and positive + answers = [] + words = model_outputs["words"] + + # TODO: Currently, we expect the length of model_outputs to be 1, because we do not stride + # in the preprocessor code. When we implement that, we'll either need to handle tensors of size + # > 1 or use the ChunkPipeline and handle multiple outputs (each of size = 1). 
+ starts, ends, scores, min_null_score = select_starts_ends( + model_outputs["start_logits"], + model_outputs["end_logits"], + model_outputs["p_mask"], + model_outputs["attention_mask"].numpy() if model_outputs.get("attention_mask", None) is not None else None, + min_null_score, + top_k, + handle_impossible_answer, + max_answer_len, + ) + + word_ids = model_outputs["word_ids"][0] + for start, end, score in zip(starts, ends, scores): + word_start, word_end = word_ids[start], word_ids[end] + if word_start is not None and word_end is not None: + answers.append( + { + "score": float(score), # XXX Write a test that verifies the result is JSON-serializable + "answer": " ".join(words[word_start : word_end + 1]), + "start": word_start, + "end": word_end, + } + ) + + if handle_impossible_answer: + answers.append({"score": min_null_score, "answer": "", "start": 0, "end": 0}) + + return answers diff --git a/src/transformers/pipelines/question_answering.py b/src/transformers/pipelines/question_answering.py index 6f07382dc5..6a1a0011c5 100644 --- a/src/transformers/pipelines/question_answering.py +++ b/src/transformers/pipelines/question_answering.py @@ -42,6 +42,110 @@ if is_torch_available(): from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING + +def decode_spans( + start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray +) -> Tuple: + """ + Take the output of any `ModelForQuestionAnswering` and generate probabilities for each span to be the actual + answer. + + In addition, it filters out some unwanted/impossible cases like answer len being greater than max_answer_len or + answer end position being before the starting position. The method supports outputting the k-best answers through + the topk argument. + + Args: + start (`np.ndarray`): Individual start probabilities for each token. + end (`np.ndarray`): Individual end probabilities for each token. + topk (`int`): Indicates how many possible answer span(s) to extract from the model output. + max_answer_len (`int`): Maximum size of the answer to extract from the model's output. + undesired_tokens (`np.ndarray`): Mask determining tokens that can be part of the answer + """ + # Ensure we have batch axis + if start.ndim == 1: + start = start[None] + + if end.ndim == 1: + end = end[None] + + # Compute the score of each tuple(start, end) to be the real answer + outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1)) + + # Remove candidate with end < start and end - start > max_answer_len + candidates = np.tril(np.triu(outer), max_answer_len - 1) + + # Inspired by Chen & al. 
(https://github.com/facebookresearch/DrQA) + scores_flat = candidates.flatten() + if topk == 1: + idx_sort = [np.argmax(scores_flat)] + elif len(scores_flat) < topk: + idx_sort = np.argsort(-scores_flat) + else: + idx = np.argpartition(-scores_flat, topk)[0:topk] + idx_sort = idx[np.argsort(-scores_flat[idx])] + + starts, ends = np.unravel_index(idx_sort, candidates.shape)[1:] + desired_spans = np.isin(starts, undesired_tokens.nonzero()) & np.isin(ends, undesired_tokens.nonzero()) + starts = starts[desired_spans] + ends = ends[desired_spans] + scores = candidates[0, starts, ends] + + return starts, ends, scores + + +def select_starts_ends( + start, + end, + p_mask, + attention_mask, + min_null_score=1000000, + top_k=1, + handle_impossible_answer=False, + max_answer_len=15, +): + """ + Takes the raw output of any `ModelForQuestionAnswering` and first normalizes its outputs and then uses + `decode_spans()` to generate probabilities for each span to be the actual answer. + + Args: + start (`np.ndarray`): Individual start logits for each token. + end (`np.ndarray`): Individual end logits for each token. + p_mask (`np.ndarray`): A mask with 1 for values that cannot be in the answer + attention_mask (`np.ndarray`): The attention mask generated by the tokenizer + min_null_score(`float`): The minimum null (empty) answer score seen so far. + topk (`int`): Indicates how many possible answer span(s) to extract from the model output. + handle_impossible_answer(`bool`): Whether to allow null (empty) answers + max_answer_len (`int`): Maximum size of the answer to extract from the model's output. + """ + # Ensure padded tokens & question tokens cannot belong to the set of candidate answers. + undesired_tokens = np.abs(np.array(p_mask) - 1) + + if attention_mask is not None: + undesired_tokens = undesired_tokens & attention_mask + + # Generate mask + undesired_tokens_mask = undesired_tokens == 0.0 + + # Make sure non-context indexes in the tensor cannot contribute to the softmax + start = np.where(undesired_tokens_mask, -10000.0, start) + end = np.where(undesired_tokens_mask, -10000.0, end) + + # Normalize logits and spans to retrieve the answer + start = np.exp(start - start.max(axis=-1, keepdims=True)) + start = start / start.sum() + + end = np.exp(end - end.max(axis=-1, keepdims=True)) + end = end / end.sum() + + if handle_impossible_answer: + min_null_score = min(min_null_score, (start[0, 0] * end[0, 0]).item()) + + # Mask CLS + start[0, 0] = end[0, 0] = 0.0 + + starts, ends, scores = decode_spans(start, end, top_k, max_answer_len, undesired_tokens) + return starts, ends, scores, min_null_score + + class QuestionAnsweringArgumentHandler(ArgumentHandler): """ QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped to @@ -141,7 +245,7 @@ class QuestionAnsweringPipeline(ChunkPipeline): framework: Optional[str] = None, device: int = -1, task: str = "", - **kwargs + **kwargs, ): super().__init__( model=model, @@ -410,34 +514,15 @@ class QuestionAnsweringPipeline(ChunkPipeline): start_ = output["start"] end_ = output["end"] example = output["example"] + p_mask = output["p_mask"] + attention_mask = ( + output["attention_mask"].numpy() if output.get("attention_mask", None) is not None else None + ) - # Ensure padded tokens & question tokens cannot belong to the set of candidate answers. 
- undesired_tokens = np.abs(np.array(output["p_mask"]) - 1) + starts, ends, scores, min_null_score = select_starts_ends( + start_, end_, p_mask, attention_mask, min_null_score, top_k, handle_impossible_answer, max_answer_len + ) - if output.get("attention_mask", None) is not None: - undesired_tokens = undesired_tokens & output["attention_mask"].numpy() - - # Generate mask - undesired_tokens_mask = undesired_tokens == 0.0 - - # Make sure non-context indexes in the tensor cannot contribute to the softmax - start_ = np.where(undesired_tokens_mask, -10000.0, start_) - end_ = np.where(undesired_tokens_mask, -10000.0, end_) - - # Normalize logits and spans to retrieve the answer - start_ = np.exp(start_ - start_.max(axis=-1, keepdims=True)) - start_ = start_ / start_.sum() - - end_ = np.exp(end_ - end_.max(axis=-1, keepdims=True)) - end_ = end_ / end_.sum() - - if handle_impossible_answer: - min_null_score = min(min_null_score, (start_[0, 0] * end_[0, 0]).item()) - - # Mask CLS - start_[0, 0] = end_[0, 0] = 0.0 - - starts, ends, scores = self.decode(start_, end_, top_k, max_answer_len, undesired_tokens) if not self.tokenizer.is_fast: char_to_word = np.array(example.char_to_word_offset) @@ -518,55 +603,6 @@ class QuestionAnsweringPipeline(ChunkPipeline): end_index = enc.offsets[e][1] return start_index, end_index - def decode( - self, start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray - ) -> Tuple: - """ - Take the output of any `ModelForQuestionAnswering` and will generate probabilities for each span to be the - actual answer. - - In addition, it filters out some unwanted/impossible cases like answer len being greater than max_answer_len or - answer end position being before the starting position. The method supports output the k-best answer through - the topk argument. - - Args: - start (`np.ndarray`): Individual start probabilities for each token. - end (`np.ndarray`): Individual end probabilities for each token. - topk (`int`): Indicates how many possible answer span(s) to extract from the model output. - max_answer_len (`int`): Maximum size of the answer to extract from the model's output. - undesired_tokens (`np.ndarray`): Mask determining tokens that can be part of the answer - """ - # Ensure we have batch axis - if start.ndim == 1: - start = start[None] - - if end.ndim == 1: - end = end[None] - - # Compute the score of each tuple(start, end) to be the real answer - outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1)) - - # Remove candidate with end < start and end - start > max_answer_len - candidates = np.tril(np.triu(outer), max_answer_len - 1) - - # Inspired by Chen & al. (https://github.com/facebookresearch/DrQA) - scores_flat = candidates.flatten() - if topk == 1: - idx_sort = [np.argmax(scores_flat)] - elif len(scores_flat) < topk: - idx_sort = np.argsort(-scores_flat) - else: - idx = np.argpartition(-scores_flat, topk)[0:topk] - idx_sort = idx[np.argsort(-scores_flat[idx])] - - starts, ends = np.unravel_index(idx_sort, candidates.shape)[1:] - desired_spans = np.isin(starts, undesired_tokens.nonzero()) & np.isin(ends, undesired_tokens.nonzero()) - starts = starts[desired_spans] - ends = ends[desired_spans] - scores = candidates[0, starts, ends] - - return starts, ends, scores - def span_to_answer(self, text: str, start: int, end: int) -> Dict[str, Union[str, int]]: """ When decoding from token probabilities, this method maps token indexes to actual word in the initial context. 
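For reference, the span scoring performed by `decode_spans` above can be reproduced with a few lines of NumPy (a toy sketch with made-up probabilities):

    import numpy as np

    # Toy per-token start/end probabilities for a 4-token context, with the batch axis decode_spans expects.
    start = np.array([[0.1, 0.6, 0.2, 0.1]])
    end = np.array([[0.1, 0.1, 0.7, 0.1]])

    # Score every (start, end) pair as start[i] * end[j] ...
    outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))
    # ... then zero out pairs with end < start (triu) or length > max_answer_len (tril), here max_answer_len = 15.
    candidates = np.tril(np.triu(outer), 15 - 1)

    # The surviving maximum is the span covering tokens 1..2, with score 0.6 * 0.7 = 0.42.
    best_start, best_end = np.unravel_index(np.argmax(candidates.flatten()), candidates.shape)[1:]
    print(best_start, best_end)  # 1 2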
diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 32ba979f78..dbdf37da4c 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -358,6 +358,9 @@ MODEL_FOR_CAUSAL_LM_MAPPING = None MODEL_FOR_CTC_MAPPING = None +MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = None + + MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None @@ -463,6 +466,13 @@ class AutoModelForCTC(metaclass=DummyObject): requires_backends(self, ["torch"]) +class AutoModelForDocumentQuestionAnswering(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class AutoModelForImageClassification(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index bc3eb64ca4..69e11eeb31 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -265,6 +265,9 @@ class TFAlbertPreTrainedModel(metaclass=DummyObject): TF_MODEL_FOR_CAUSAL_LM_MAPPING = None +TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING = None + + TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING = None @@ -327,6 +330,13 @@ class TFAutoModelForCausalLM(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFAutoModelForDocumentQuestionAnswering(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFAutoModelForImageClassification(metaclass=DummyObject): _backends = ["tf"] diff --git a/src/transformers/utils/fx.py b/src/transformers/utils/fx.py index aec3c950ae..c08f6766c9 100644 --- a/src/transformers/utils/fx.py +++ b/src/transformers/utils/fx.py @@ -36,6 +36,7 @@ from ..models.auto.modeling_auto import ( MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_FOR_CTC_MAPPING_NAMES, + MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING_NAMES, MODEL_FOR_MASKED_LM_MAPPING_NAMES, @@ -71,6 +72,7 @@ def _generate_supported_model_class_names( "seq2seq-lm": MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES, "speech-seq2seq": MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES, "multiple-choice": MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES, + "document-question-answering": MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES, "question-answering": MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES, "sequence-classification": MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES, "token-classification": MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, @@ -147,7 +149,6 @@ _SPECIAL_SUPPORTED_MODELS = [ "GPT2DoubleHeadsModel", "Speech2Text2Decoder", "TrOCRDecoder", - "LayoutLMForQuestionAnswering", # TODO: add support for them as it should be quite easy to do so (small blocking issues). 
# XLNetForQuestionAnswering, ] @@ -691,7 +692,7 @@ class HFTracer(Tracer): inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) elif model_class_name in [ *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES), - "LayoutLMForQuestionAnswering", + *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES), "XLNetForQuestionAnswering", ]: inputs_dict["start_positions"] = torch.zeros(batch_size, dtype=torch.long, device=device) diff --git a/tests/models/layoutlm/test_modeling_layoutlm.py b/tests/models/layoutlm/test_modeling_layoutlm.py index cce3c9b3f4..16cacab88c 100644 --- a/tests/models/layoutlm/test_modeling_layoutlm.py +++ b/tests/models/layoutlm/test_modeling_layoutlm.py @@ -12,12 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import copy import unittest from transformers import LayoutLMConfig, is_torch_available -from transformers.models.auto import get_values from transformers.testing_utils import require_torch, slow, torch_device from ...test_configuration_common import ConfigTester @@ -28,9 +25,6 @@ if is_torch_available(): import torch from transformers import ( - MODEL_FOR_MASKED_LM_MAPPING, - MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, - MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, LayoutLMForMaskedLM, LayoutLMForQuestionAnswering, LayoutLMForSequenceClassification, @@ -273,30 +267,6 @@ class LayoutLMModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_question_answering(*config_and_inputs) - def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): - inputs_dict = copy.deepcopy(inputs_dict) - if return_labels: - if model_class in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): - inputs_dict["labels"] = torch.zeros( - self.model_tester.batch_size, dtype=torch.long, device=torch_device - ) - elif model_class in [ - *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING), - *get_values(MODEL_FOR_MASKED_LM_MAPPING), - ]: - inputs_dict["labels"] = torch.zeros( - (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device - ) - elif model_class.__name__ == "LayoutLMForQuestionAnswering": - inputs_dict["start_positions"] = torch.zeros( - self.model_tester.batch_size, dtype=torch.long, device=torch_device - ) - inputs_dict["end_positions"] = torch.zeros( - self.model_tester.batch_size, dtype=torch.long, device=torch_device - ) - - return inputs_dict - def prepare_layoutlm_batch_inputs(): # Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on: diff --git a/tests/models/layoutlm/test_modeling_tf_layoutlm.py b/tests/models/layoutlm/test_modeling_tf_layoutlm.py index 9323b0bb9b..4224f20a1d 100644 --- a/tests/models/layoutlm/test_modeling_tf_layoutlm.py +++ b/tests/models/layoutlm/test_modeling_tf_layoutlm.py @@ -13,13 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import copy import unittest import numpy as np from transformers import LayoutLMConfig, is_tf_available -from transformers.models.auto import get_values from transformers.testing_utils import require_tf, slow from ...test_configuration_common import ConfigTester @@ -29,11 +27,6 @@ from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor, random_at if is_tf_available(): import tensorflow as tf - from transformers import ( - TF_MODEL_FOR_MASKED_LM_MAPPING, - TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, - TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, - ) from transformers.models.layoutlm.modeling_tf_layoutlm import ( TF_LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST, TFLayoutLMForMaskedLM, @@ -263,24 +256,6 @@ class TFLayoutLMModelTest(TFModelTesterMixin, unittest.TestCase): model = TFLayoutLMModel.from_pretrained(model_name) self.assertIsNotNone(model) - def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): - inputs_dict = copy.deepcopy(inputs_dict) - if return_labels: - if model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING): - inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) - elif model_class in [ - *get_values(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING), - *get_values(TF_MODEL_FOR_MASKED_LM_MAPPING), - ]: - inputs_dict["labels"] = tf.zeros( - (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32 - ) - elif model_class.__name__ == "TFLayoutLMForQuestionAnswering": - inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) - inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) - - return inputs_dict - def prepare_layoutlm_batch_inputs(): # Here we prepare a batch of 2 sequences to test a LayoutLM forward pass on: diff --git a/tests/pipelines/test_pipelines_document_question_answering.py b/tests/pipelines/test_pipelines_document_question_answering.py new file mode 100644 index 0000000000..7bf8ec99fb --- /dev/null +++ b/tests/pipelines/test_pipelines_document_question_answering.py @@ -0,0 +1,280 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from transformers import MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, AutoTokenizer, is_vision_available +from transformers.pipelines import pipeline +from transformers.pipelines.document_question_answering import apply_tesseract +from transformers.testing_utils import ( + is_pipeline_test, + nested_simplify, + require_detectron2, + require_pytesseract, + require_tf, + require_torch, + require_vision, + slow, +) + +from .test_pipelines_common import ANY, PipelineTestCaseMeta + + +if is_vision_available(): + from PIL import Image + + from transformers.image_utils import load_image +else: + + class Image: + @staticmethod + def open(*args, **kwargs): + pass + + def load_image(_): + return None + + +# This is a pinned image from a specific revision of a document question answering space, hosted by HuggingFace, +# so we can expect it to be available. +INVOICE_URL = ( + "https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/invoice.png" +) + + +@is_pipeline_test +@require_torch +@require_vision +class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta): + model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING + + @require_pytesseract + @require_vision + def get_test_pipeline(self, model, tokenizer, feature_extractor): + dqa_pipeline = pipeline( + "document-question-answering", model=model, tokenizer=tokenizer, feature_extractor=feature_extractor + ) + + image = INVOICE_URL + word_boxes = list(zip(*apply_tesseract(load_image(image), None, ""))) + question = "What is the placebo?" + examples = [ + { + "image": load_image(image), + "question": question, + }, + { + "image": image, + "question": question, + }, + { + "image": image, + "question": question, + "word_boxes": word_boxes, + }, + { + "image": None, + "question": question, + "word_boxes": word_boxes, + }, + ] + return dqa_pipeline, examples + + def run_pipeline_test(self, dqa_pipeline, examples): + outputs = dqa_pipeline(examples, top_k=2) + self.assertEqual( + outputs, + [ + [ + {"score": ANY(float), "answer": ANY(str), "start": ANY(int), "end": ANY(int)}, + {"score": ANY(float), "answer": ANY(str), "start": ANY(int), "end": ANY(int)}, + ] + ] + * 4, + ) + + @require_torch + @require_detectron2 + @require_pytesseract + def test_small_model_pt(self): + dqa_pipeline = pipeline("document-question-answering", model="hf-internal-testing/tiny-random-layoutlmv2") + image = INVOICE_URL + question = "How many cats are there?" + + expected_output = [ + { + "score": 0.0001, + "answer": "2312/2019 DUE DATE 26102/2019 ay DESCRIPTION UNIT PRICE", + "start": 38, + "end": 45, + }, + {"score": 0.0001, "answer": "2312/2019 DUE", "start": 38, "end": 39}, + ] + outputs = dqa_pipeline(image=image, question=question, top_k=2) + self.assertEqual(nested_simplify(outputs, decimals=4), expected_output) + + outputs = dqa_pipeline({"image": image, "question": question}, top_k=2) + self.assertEqual(nested_simplify(outputs, decimals=4), expected_output) + + # This image does not detect ANY text in it, meaning layoutlmv2 should fail. 
+ # Empty answer probably + image = "./tests/fixtures/tests_samples/COCO/000000039769.png" + outputs = dqa_pipeline(image=image, question=question, top_k=2) + self.assertEqual(outputs, []) + + # We can optionnally pass directly the words and bounding boxes + image = "./tests/fixtures/tests_samples/COCO/000000039769.png" + words = [] + boxes = [] + outputs = dqa_pipeline(image=image, question=question, words=words, boxes=boxes, top_k=2) + self.assertEqual(outputs, []) + + # TODO: Enable this once hf-internal-testing/tiny-random-donut is implemented + # @require_torch + # def test_small_model_pt_donut(self): + # dqa_pipeline = pipeline("document-question-answering", model="hf-internal-testing/tiny-random-donut") + # # dqa_pipeline = pipeline("document-question-answering", model="../tiny-random-donut") + # image = "https://templates.invoicehome.com/invoice-template-us-neat-750px.png" + # question = "How many cats are there?" + # + # outputs = dqa_pipeline(image=image, question=question, top_k=2) + # self.assertEqual( + # nested_simplify(outputs, decimals=4), [{"score": 0.8799, "answer": "2"}, {"score": 0.296, "answer": "1"}] + # ) + + @slow + @require_torch + @require_detectron2 + @require_pytesseract + def test_large_model_pt(self): + dqa_pipeline = pipeline( + "document-question-answering", + model="tiennvcs/layoutlmv2-base-uncased-finetuned-docvqa", + revision="9977165", + ) + image = INVOICE_URL + question = "What is the invoice number?" + + outputs = dqa_pipeline(image=image, question=question, top_k=2) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"score": 0.9966, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.0009, "answer": "us-001", "start": 15, "end": 15}, + ], + ) + + outputs = dqa_pipeline({"image": image, "question": question}, top_k=2) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"score": 0.9966, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.0009, "answer": "us-001", "start": 15, "end": 15}, + ], + ) + + outputs = dqa_pipeline( + [{"image": image, "question": question}, {"image": image, "question": question}], top_k=2 + ) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.9966, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.0009, "answer": "us-001", "start": 15, "end": 15}, + ], + ] + * 2, + ) + + @slow + @require_torch + @require_pytesseract + @require_vision + def test_large_model_pt_layoutlm(self): + tokenizer = AutoTokenizer.from_pretrained( + "impira/layoutlm-document-qa", revision="3dc6de3", add_prefix_space=True + ) + dqa_pipeline = pipeline( + "document-question-answering", + model="impira/layoutlm-document-qa", + tokenizer=tokenizer, + revision="3dc6de3", + ) + image = INVOICE_URL + question = "What is the invoice number?" 
+ + outputs = dqa_pipeline(image=image, question=question, top_k=2) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, + ], + ) + + outputs = dqa_pipeline({"image": image, "question": question}, top_k=2) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, + ], + ) + + outputs = dqa_pipeline( + [{"image": image, "question": question}, {"image": image, "question": question}], top_k=2 + ) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + [ + {"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, + ] + ] + * 2, + ) + + word_boxes = list(zip(*apply_tesseract(load_image(image), None, ""))) + + # This model should also work if `image` is set to None + outputs = dqa_pipeline({"image": None, "word_boxes": word_boxes, "question": question}, top_k=2) + self.assertEqual( + nested_simplify(outputs, decimals=4), + [ + {"score": 0.9998, "answer": "us-001", "start": 15, "end": 15}, + {"score": 0.0, "answer": "INVOICE # us-001", "start": 13, "end": 15}, + ], + ) + + @slow + @require_torch + def test_large_model_pt_donut(self): + dqa_pipeline = pipeline( + "document-question-answering", + model="naver-clova-ix/donut-base-finetuned-docvqa", + tokenizer=AutoTokenizer.from_pretrained("naver-clova-ix/donut-base-finetuned-docvqa"), + feature_extractor="naver-clova-ix/donut-base-finetuned-docvqa", + ) + + image = INVOICE_URL + question = "What is the invoice number?" + outputs = dqa_pipeline(image=image, question=question, top_k=2) + self.assertEqual(nested_simplify(outputs, decimals=4), {"answer": "us-001"}) + + @require_tf + @unittest.skip("Document question answering not implemented in TF") + def test_small_model_tf(self): + pass diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 05921334a6..6c4814c1a8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -89,6 +89,7 @@ if is_torch_available(): MODEL_FOR_AUDIO_XVECTOR_MAPPING, MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, MODEL_FOR_MASKED_LM_MAPPING, @@ -172,7 +173,10 @@ class ModelTesterMixin: if return_labels: if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING): inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device) - elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): + elif model_class in [ + *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING), + *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING), + ]: inputs_dict["start_positions"] = torch.zeros( self.model_tester.batch_size, dtype=torch.long, device=torch_device ) @@ -542,7 +546,10 @@ class ModelTesterMixin: if "labels" in inputs_dict: correct_outlen += 1 # loss is added to beginning # Question Answering model returns start_logits and end_logits - if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING): + if model_class in [ + *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING), + *get_values(MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING), + ]: correct_outlen += 1 # start_logits and end_logits instead of only 1 
output if "past_key_values" in outputs: correct_outlen += 1 # past_key_values have been returned diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index f3608f4b22..0ef457c035 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -61,6 +61,7 @@ if is_tf_available(): from transformers import ( TF_MODEL_FOR_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, @@ -149,7 +150,10 @@ class TFModelTesterMixin: if return_labels: if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING): inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32) - elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING): + elif model_class in [ + *get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING), + *get_values(TF_MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING), + ]: inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) elif model_class in [