diff --git a/src/transformers/pipelines/document_question_answering.py b/src/transformers/pipelines/document_question_answering.py
index 2023ce9e01..1b14c1f480 100644
--- a/src/transformers/pipelines/document_question_answering.py
+++ b/src/transformers/pipelines/document_question_answering.py
@@ -25,7 +25,7 @@ from ..utils import (
     is_vision_available,
     logging,
 )
-from .base import PIPELINE_INIT_ARGS, Pipeline
+from .base import PIPELINE_INIT_ARGS, ChunkPipeline
 from .question_answering import select_starts_ends
 
 
@@ -49,7 +49,7 @@ logger = logging.get_logger(__name__)
 
 # normalize_bbox() and apply_tesseract() are derived from apply_tesseract in models/layoutlmv3/feature_extraction_layoutlmv3.py.
 # However, because the pipeline may evolve from what layoutlmv3 currently does, it's copied (vs. imported) to avoid creating an
-# unecessary dependency.
+# unnecessary dependency.
 def normalize_box(box, width, height):
     return [
         int(1000 * (box[0] / width)),
@@ -99,7 +99,7 @@ class ModelType(ExplicitEnum):
 
 
 @add_end_docstrings(PIPELINE_INIT_ARGS)
-class DocumentQuestionAnsweringPipeline(Pipeline):
+class DocumentQuestionAnsweringPipeline(ChunkPipeline):
     # TODO: Update task_summary docs to include an example with document QA and then update the first sentence
     """
     Document Question Answering pipeline using any `AutoModelForDocumentQuestionAnswering`. The inputs/outputs are
@@ -234,6 +234,8 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
             - **end** (`int`) -- The end word index of the answer (in the OCR'd version of the input or provided
               `word_boxes`).
             - **answer** (`str`) -- The answer to the question.
+            - **words** (`list[int]`) -- The index of each word/box pair that is in the answer.
+            - **page** (`int`) -- The page of the answer.
         """
         if isinstance(question, str):
             inputs = {"question": question, "image": image}
@@ -243,7 +245,24 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
             inputs = image
         return super().__call__(inputs, **kwargs)
 
-    def preprocess(self, input, lang=None, tesseract_config=""):
+    def preprocess(
+        self,
+        input,
+        padding="do_not_pad",
+        doc_stride=None,
+        max_seq_len=None,
+        word_boxes: Tuple[str, List[float]] = None,
+        lang=None,
+        tesseract_config="",
+    ):
+        # NOTE: This code mirrors the logic in the question answering pipeline; striding for documents
+        # with enough tokens to overflow the model's window will be completed in a follow-up PR.
+        if max_seq_len is None:
+            max_seq_len = self.tokenizer.model_max_length
+
+        if doc_stride is None:
+            doc_stride = min(max_seq_len // 2, 256)
+
         image = None
         image_features = {}
         if input.get("image", None) is not None:
@@ -291,9 +310,15 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
                 ).input_ids,
                 "return_dict_in_generate": True,
             }
-            p_mask = None
-            word_ids = None
-            words = None
+            yield {
+                **encoding,
+                "p_mask": None,
+                "word_ids": None,
+                "words": None,
+                "page": None,
+                "output_attentions": True,
+                "is_last": True,
+            }
         else:
             tokenizer_kwargs = {}
             if self.model_type == ModelType.LayoutLM:
@@ -306,21 +331,15 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
                 tokenizer_kwargs["boxes"] = [boxes]
 
             encoding = self.tokenizer(
+                padding=padding,
+                max_length=max_seq_len,
+                stride=doc_stride,
                 return_token_type_ids=True,
-                return_tensors=self.framework,
-                # TODO: In a future PR, use these features to handle sequences whose length is longer than
-                # the maximum allowed by the model. Currently, the tokenizer will produce a sequence that
-                # may be too long for the model to handle.
-                # truncation="only_second",
-                # return_overflowing_tokens=True,
+                truncation="only_second",
+                return_overflowing_tokens=True,
                 **tokenizer_kwargs,
             )
-            if "pixel_values" in image_features:
-                encoding["image"] = image_features.pop("pixel_values")
-
-            # TODO: For now, this should always be num_spans == 1 given the flags we've passed in above, but the
-            # code is written to naturally handle multiple spans at the right time.
             num_spans = len(encoding["input_ids"])
 
             # p_mask: mask with 1 for token that cannot be in the answer (0 for token which can be in an answer)
@@ -328,6 +347,13 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
             # This logic mirrors the logic in the question_answering pipeline
             p_mask = [[tok != 1 for tok in encoding.sequence_ids(span_id)] for span_id in range(num_spans)]
             for span_idx in range(num_spans):
+                if self.framework == "pt":
+                    span_encoding = {k: torch.tensor(v[span_idx : span_idx + 1]) for (k, v) in encoding.items()}
+                    if "pixel_values" in image_features:
+                        span_encoding["image"] = image_features["pixel_values"]
+                else:
+                    raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
+
                 input_ids_span_idx = encoding["input_ids"][span_idx]
                 # keep the cls_token unmasked (some models use it to indicate unanswerable questions)
                 if self.tokenizer.cls_token_id is not None:
                     cls_indices = np.nonzero(np.array(input_ids_span_idx) == self.tokenizer.cls_token_id)[0]
                     for cls_index in cls_indices:
                         p_mask[span_idx][cls_index] = 0
 
-            # For each span, place a bounding box [0,0,0,0] for question and CLS tokens, [1000,1000,1000,1000]
-            # for SEP tokens, and the word's bounding box for words in the original document.
-            if "boxes" not in tokenizer_kwargs:
-                bbox = []
-                for batch_index in range(num_spans):
+                # For each span, place a bounding box [0,0,0,0] for question and CLS tokens, [1000,1000,1000,1000]
+                # for SEP tokens, and the word's bounding box for words in the original document.
+                if "boxes" not in tokenizer_kwargs:
+                    bbox = []
                     for input_id, sequence_id, word_id in zip(
-                        encoding.input_ids[batch_index],
-                        encoding.sequence_ids(batch_index),
-                        encoding.word_ids(batch_index),
+                        encoding.input_ids[span_idx],
+                        encoding.sequence_ids(span_idx),
+                        encoding.word_ids(span_idx),
                     ):
                         if sequence_id == 1:
                             bbox.append(boxes[word_id])
                         elif input_id == self.tokenizer.sep_token_id:
                             bbox.append([1000] * 4)
                         else:
                             bbox.append([0] * 4)
@@ -352,41 +377,50 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
-            if self.framework == "tf":
-                raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
-            elif self.framework == "pt":
-                encoding["bbox"] = torch.tensor([bbox])
-
-        word_ids = [encoding.word_ids(i) for i in range(num_spans)]
-
-        return {**encoding, "p_mask": p_mask, "word_ids": word_ids, "words": words}
+                    if self.framework == "pt":
+                        span_encoding["bbox"] = torch.tensor(bbox).unsqueeze(0)
+                    elif self.framework == "tf":
+                        raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
+
+                yield {
+                    **span_encoding,
+                    "p_mask": p_mask[span_idx],
+                    "word_ids": encoding.word_ids(span_idx),
+                    "words": words,
+                    "is_last": span_idx == num_spans - 1,
+                }
 
     def _forward(self, model_inputs):
         p_mask = model_inputs.pop("p_mask", None)
         word_ids = model_inputs.pop("word_ids", None)
         words = model_inputs.pop("words", None)
+        is_last = model_inputs.pop("is_last", False)
+
+        if "overflow_to_sample_mapping" in model_inputs:
+            model_inputs.pop("overflow_to_sample_mapping")
 
         if self.model_type == ModelType.VisionEncoderDecoder:
             model_outputs = self.model.generate(**model_inputs)
         else:
             model_outputs = self.model(**model_inputs)
 
+        model_outputs = {k: v for (k, v) in model_outputs.items()}
         model_outputs["p_mask"] = p_mask
         model_outputs["word_ids"] = word_ids
         model_outputs["words"] = words
         model_outputs["attention_mask"] = model_inputs.get("attention_mask", None)
+        model_outputs["is_last"] = is_last
         return model_outputs
 
     def postprocess(self, model_outputs, top_k=1, **kwargs):
         if self.model_type == ModelType.VisionEncoderDecoder:
-            answers = self.postprocess_donut(model_outputs)
+            answers = [self.postprocess_encoder_decoder_single(o) for o in model_outputs]
         else:
             answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs)
 
         answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k]
         return answers
 
-    def postprocess_donut(self, model_outputs, **kwargs):
+    def postprocess_encoder_decoder_single(self, model_outputs, **kwargs):
         sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0]
 
         # TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
@@ -400,41 +434,40 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
         answer = re.search(r"<s_answer>(.*)</s_answer>", sequence)
         if answer is not None:
             ret["answer"] = answer.group(1).strip()
-        return [ret]
+        return ret
 
     def postprocess_extractive_qa(
         self, model_outputs, top_k=1, handle_impossible_answer=False, max_answer_len=15, **kwargs
    ):
         min_null_score = 1000000  # large and positive
         answers = []
-        words = model_outputs["words"]
+        for output in model_outputs:
+            words = output["words"]
 
-        # TODO: Currently, we expect the length of model_outputs to be 1, because we do not stride
-        # in the preprocessor code. When we implement that, we'll either need to handle tensors of size
-        # > 1 or use the ChunkPipeline and handle multiple outputs (each of size = 1).
-        starts, ends, scores, min_null_score = select_starts_ends(
-            model_outputs["start_logits"],
-            model_outputs["end_logits"],
-            model_outputs["p_mask"],
-            model_outputs["attention_mask"].numpy() if model_outputs.get("attention_mask", None) is not None else None,
-            min_null_score,
-            top_k,
-            handle_impossible_answer,
-            max_answer_len,
-        )
-
-        word_ids = model_outputs["word_ids"][0]
-        for start, eend, score in zip(starts, ends, scores):
-            word_start, word_end = word_ids[start], word_ids[eend]
-            if word_start is not None and word_end is not None:
-                answers.append(
-                    {
-                        "score": float(score),  # XXX Write a test that verifies the result is JSON-serializable
-                        "answer": " ".join(words[word_start : word_end + 1]),
-                        "start": word_start,
-                        "end": word_end,
-                    }
-                )
+            starts, ends, scores, min_null_score = select_starts_ends(
+                start=output["start_logits"],
+                end=output["end_logits"],
+                p_mask=output["p_mask"],
+                attention_mask=output["attention_mask"].numpy()
+                if output.get("attention_mask", None) is not None
+                else None,
+                min_null_score=min_null_score,
+                top_k=top_k,
+                handle_impossible_answer=handle_impossible_answer,
+                max_answer_len=max_answer_len,
+            )
+            word_ids = output["word_ids"]
+            for start, end, score in zip(starts, ends, scores):
+                word_start, word_end = word_ids[start], word_ids[end]
+                if word_start is not None and word_end is not None:
+                    answers.append(
+                        {
+                            "score": float(score),
+                            "answer": " ".join(words[word_start : word_end + 1]),
+                            "start": word_start,
+                            "end": word_end,
+                        }
+                    )
 
         if handle_impossible_answer:
             answers.append({"score": min_null_score, "answer": "", "start": 0, "end": 0})
 
diff --git a/tests/pipelines/test_pipelines_document_question_answering.py b/tests/pipelines/test_pipelines_document_question_answering.py
index bea8335c5c..fa272d6492 100644
--- a/tests/pipelines/test_pipelines_document_question_answering.py
+++ b/tests/pipelines/test_pipelines_document_question_answering.py
@@ -191,6 +191,52 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
             * 2,
         )
 
+    @slow
+    @require_torch
+    @require_detectron2
+    @require_pytesseract
+    def test_large_model_pt_chunk(self):
+        dqa_pipeline = pipeline(
+            "document-question-answering",
+            model="tiennvcs/layoutlmv2-base-uncased-finetuned-docvqa",
+            revision="9977165",
+            max_seq_len=50,
+        )
+        image = INVOICE_URL
+        question = "What is the invoice number?"
+
+        outputs = dqa_pipeline(image=image, question=question, top_k=2)
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
+                {"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
+            ],
+        )
+
+        outputs = dqa_pipeline({"image": image, "question": question}, top_k=2)
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
+                {"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
+            ],
+        )
+
+        outputs = dqa_pipeline(
+            [{"image": image, "question": question}, {"image": image, "question": question}], top_k=2
+        )
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                [
+                    {"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
+                    {"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
+                ]
+            ]
+            * 2,
+        )
+
     @slow
     @require_torch
     @require_pytesseract
@@ -252,6 +298,59 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
             ],
         )
 
+    @slow
+    @require_torch
+    @require_pytesseract
+    @require_vision
+    def test_large_model_pt_layoutlm_chunk(self):
+        tokenizer = AutoTokenizer.from_pretrained(
+            "impira/layoutlm-document-qa", revision="3dc6de3", add_prefix_space=True
+        )
+        dqa_pipeline = pipeline(
+            "document-question-answering",
+            model="impira/layoutlm-document-qa",
+            tokenizer=tokenizer,
+            revision="3dc6de3",
+            max_seq_len=50,
+        )
+        image = INVOICE_URL
+        question = "What is the invoice number?"
+
+        outputs = dqa_pipeline(image=image, question=question, top_k=2)
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
+                {"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
+            ],
+        )
+
+        outputs = dqa_pipeline(
+            [{"image": image, "question": question}, {"image": image, "question": question}], top_k=2
+        )
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                [
+                    {"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
+                    {"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
+                ]
+            ]
+            * 2,
+        )
+
+        word_boxes = list(zip(*apply_tesseract(load_image(image), None, "")))
+
+        # This model should also work if `image` is set to None
+        outputs = dqa_pipeline({"image": None, "word_boxes": word_boxes, "question": question}, top_k=2)
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
+                {"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
+            ],
+        )
+
     @slow
     @require_torch
     def test_large_model_pt_donut(self):
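
Usage note (not part of the diff above): a minimal sketch of how the new chunking path can be
exercised end to end. It assumes network access to the public `impira/layoutlm-document-qa`
checkpoint, a hypothetical local file `invoice.png`, and that `max_seq_len`/`doc_stride` are
routed through `_sanitize_parameters` into `preprocess()`, as the tests above assume when they
pass `max_seq_len=50` at pipeline construction.

    from transformers import pipeline

    # An artificially small max_seq_len forces the tokenizer (now called with
    # truncation="only_second" and return_overflowing_tokens=True) to produce
    # several overlapping spans: preprocess() yields one encoding per span,
    # _forward() runs the model span by span, and postprocess() merges the
    # per-span answers and keeps the global top_k by score.
    dqa = pipeline(
        "document-question-answering",
        model="impira/layoutlm-document-qa",
        max_seq_len=50,
        doc_stride=25,
    )

    answers = dqa(image="invoice.png", question="What is the invoice number?", top_k=2)
    for answer in answers:
        # Each answer is JSON-serializable: "score" is a plain float, and
        # "start"/"end" are word indices into the OCR'd (or provided) word/box list.
        print(answer["score"], answer["answer"], answer["start"], answer["end"])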