Implement multiple span support for DocumentQuestionAnswering (#19204)

* Implement multiple span support

* Address comments

* Add tests + fix bugs
Ankur Goyal 2022-10-11 07:47:55 -07:00 committed by GitHub
parent ab856f68df
commit a3008c5a6d
2 changed files with 196 additions and 64 deletions
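For context, a minimal usage sketch (not part of this commit): with multiple-span support, a small `max_seq_len` makes the pipeline split a long document into overlapping spans instead of truncating it. The model name comes from the tests below; `"invoice.png"` is a placeholder path.

```python
from transformers import pipeline

# Sketch only: force several spans even for short documents by shrinking the window.
dqa = pipeline(
    "document-question-answering",
    model="impira/layoutlm-document-qa",
    max_seq_len=50,
)
print(dqa(image="invoice.png", question="What is the invoice number?", top_k=2))
```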

src/transformers/pipelines/document_question_answering.py

@@ -25,7 +25,7 @@ from ..utils import (
     is_vision_available,
     logging,
 )
-from .base import PIPELINE_INIT_ARGS, Pipeline
+from .base import PIPELINE_INIT_ARGS, ChunkPipeline
 from .question_answering import select_starts_ends
@@ -49,7 +49,7 @@ logger = logging.get_logger(__name__)
 # normalize_bbox() and apply_tesseract() are derived from apply_tesseract in models/layoutlmv3/feature_extraction_layoutlmv3.py.
 # However, because the pipeline may evolve from what layoutlmv3 currently does, it's copied (vs. imported) to avoid creating an
-# unecessary dependency.
+# unnecessary dependency.
 def normalize_box(box, width, height):
     return [
         int(1000 * (box[0] / width)),
@@ -99,7 +99,7 @@ class ModelType(ExplicitEnum):
 @add_end_docstrings(PIPELINE_INIT_ARGS)
-class DocumentQuestionAnsweringPipeline(Pipeline):
+class DocumentQuestionAnsweringPipeline(ChunkPipeline):
     # TODO: Update task_summary docs to include an example with document QA and then update the first sentence
     """
     Document Question Answering pipeline using any `AutoModelForDocumentQuestionAnswering`. The inputs/outputs are
@@ -234,6 +234,8 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
             - **end** (`int`) -- The end word index of the answer (in the OCR'd version of the input or provided
               `word_boxes`).
             - **answer** (`str`) -- The answer to the question.
+            - **words** (`list[int]`) -- The index of each word/box pair that is in the answer
+            - **page** (`int`) -- The page of the answer
         """
         if isinstance(question, str):
             inputs = {"question": question, "image": image}
@@ -243,7 +245,24 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
             inputs = image
         return super().__call__(inputs, **kwargs)

-    def preprocess(self, input, lang=None, tesseract_config=""):
+    def preprocess(
+        self,
+        input,
+        padding="do_not_pad",
+        doc_stride=None,
+        max_seq_len=None,
+        word_boxes: Tuple[str, List[float]] = None,
+        lang=None,
+        tesseract_config="",
+    ):
+        # NOTE: This code mirrors the code in question answering and will be implemented in a follow up PR
+        # to support documents with enough tokens that overflow the model's window
+        if max_seq_len is None:
+            max_seq_len = self.tokenizer.model_max_length
+
+        if doc_stride is None:
+            doc_stride = min(max_seq_len // 2, 256)
+
         image = None
         image_features = {}
         if input.get("image", None) is not None:
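The defaults above mean that, unless the caller overrides them, consecutive spans overlap by half the model window, capped at 256 tokens. A quick sketch of the arithmetic for a typical 512-token model (numbers illustrative):

```python
# Default stride computed as in preprocess() above:
max_seq_len = 512                        # e.g. tokenizer.model_max_length
doc_stride = min(max_seq_len // 2, 256)  # -> 256 tokens shared between adjacent spans
```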
@@ -291,9 +310,15 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
                 ).input_ids,
                 "return_dict_in_generate": True,
             }
-            p_mask = None
-            word_ids = None
-            words = None
+            yield {
+                **encoding,
+                "p_mask": None,
+                "word_ids": None,
+                "words": None,
+                "page": None,
+                "output_attentions": True,
+                "is_last": True,
+            }
         else:
             tokenizer_kwargs = {}
             if self.model_type == ModelType.LayoutLM:
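The generative (Donut-style) branch produces exactly one chunk per document, so the span bookkeeping fields are `None` and `is_last` is immediately `True`. A simplified illustration, not the library's actual implementation, of what the flag enables downstream:

```python
# Sketch: "is_last" marks where one logical input's chunks end, so per-span
# model outputs can be regrouped before postprocess() sees them (simplified).
def regroup(chunks):
    batch = []
    for chunk in chunks:
        batch.append(chunk)
        if chunk["is_last"]:
            yield batch  # one list of per-span outputs per original input
            batch = []
```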
@@ -306,21 +331,15 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
                 tokenizer_kwargs["boxes"] = [boxes]

             encoding = self.tokenizer(
+                padding=padding,
+                max_length=max_seq_len,
+                stride=doc_stride,
                 return_token_type_ids=True,
-                return_tensors=self.framework,
-                # TODO: In a future PR, use these feature to handle sequences whose length is longer than
-                # the maximum allowed by the model. Currently, the tokenizer will produce a sequence that
-                # may be too long for the model to handle.
-                # truncation="only_second",
-                # return_overflowing_tokens=True,
+                truncation="only_second",
+                return_overflowing_tokens=True,
                 **tokenizer_kwargs,
             )
-            if "pixel_values" in image_features:
-                encoding["image"] = image_features.pop("pixel_values")
-
-            # TODO: For now, this should always be num_spans == 1 given the flags we've passed in above, but the
-            # code is written to naturally handle multiple spans at the right time.
             num_spans = len(encoding["input_ids"])

             # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
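The switch from `return_tensors` to `truncation="only_second"` plus `return_overflowing_tokens=True` is what actually produces multiple spans. A rough standalone sketch of that tokenizer behavior, assuming the RoBERTa-based tokenizer from the tests and an invented word list:

```python
from transformers import AutoTokenizer

# Illustrative only: chunk a long pretokenized document into overlapping spans.
tok = AutoTokenizer.from_pretrained("impira/layoutlm-document-qa", add_prefix_space=True)
question_words = "What is the invoice number?".split()
doc_words = ["word"] * 300  # pretend OCR output from a long page
enc = tok(
    text=question_words,
    text_pair=doc_words,
    is_split_into_words=True,
    max_length=50,
    stride=25,                   # overlap kept between consecutive spans
    truncation="only_second",    # only the document side is ever truncated
    return_overflowing_tokens=True,
)
print(len(enc["input_ids"]))     # > 1: several overlapping spans, not one truncated one
```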
@@ -328,6 +347,13 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
             # This logic mirrors the logic in the question_answering pipeline
             p_mask = [[tok != 1 for tok in encoding.sequence_ids(span_id)] for span_id in range(num_spans)]
             for span_idx in range(num_spans):
+                if self.framework == "pt":
+                    span_encoding = {k: torch.tensor(v[span_idx : span_idx + 1]) for (k, v) in encoding.items()}
+                    if "pixel_values" in image_features:
+                        span_encoding["image"] = image_features["pixel_values"]
+                else:
+                    raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
+
                 input_ids_span_idx = encoding["input_ids"][span_idx]
                 # keep the cls_token unmasked (some models use it to indicate unanswerable questions)
                 if self.tokenizer.cls_token_id is not None:
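`p_mask` marks tokens that can never be part of an answer: the question and special tokens. A toy illustration of the `sequence_ids`-based construction above, with made-up values:

```python
# sequence_ids() is None for special tokens, 0 for question tokens,
# 1 for document words; only sequence-1 tokens may appear in an answer.
sequence_ids = [None, 0, 0, 0, None, 1, 1, 1, 1, None]
p_mask = [tok != 1 for tok in sequence_ids]
print(p_mask)  # [True, True, True, True, True, False, False, False, False, True]
```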
@@ -335,15 +361,14 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
                     for cls_index in cls_indices:
                         p_mask[span_idx][cls_index] = 0

-            # For each span, place a bounding box [0,0,0,0] for question and CLS tokens, [1000,1000,1000,1000]
-            # for SEP tokens, and the word's bounding box for words in the original document.
-            if "boxes" not in tokenizer_kwargs:
-                bbox = []
-                for batch_index in range(num_spans):
+                # For each span, place a bounding box [0,0,0,0] for question and CLS tokens, [1000,1000,1000,1000]
+                # for SEP tokens, and the word's bounding box for words in the original document.
+                if "boxes" not in tokenizer_kwargs:
+                    bbox = []
                     for input_id, sequence_id, word_id in zip(
-                        encoding.input_ids[batch_index],
-                        encoding.sequence_ids(batch_index),
-                        encoding.word_ids(batch_index),
+                        encoding.input_ids[span_idx],
+                        encoding.sequence_ids(span_idx),
+                        encoding.word_ids(span_idx),
                     ):
                         if sequence_id == 1:
                             bbox.append(boxes[word_id])
@@ -352,41 +377,50 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
                         else:
                             bbox.append([0] * 4)

-            if self.framework == "tf":
-                raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
-            elif self.framework == "pt":
-                encoding["bbox"] = torch.tensor([bbox])
-
-            word_ids = [encoding.word_ids(i) for i in range(num_spans)]
-
-        return {**encoding, "p_mask": p_mask, "word_ids": word_ids, "words": words}
+                    if self.framework == "pt":
+                        span_encoding["bbox"] = torch.tensor(bbox).unsqueeze(0)
+                    elif self.framework == "tf":
+                        raise ValueError("Unsupported: Tensorflow preprocessing for DocumentQuestionAnsweringPipeline")
+
+                yield {
+                    **span_encoding,
+                    "p_mask": p_mask[span_idx],
+                    "word_ids": encoding.word_ids(span_idx),
+                    "words": words,
+                    "is_last": span_idx == num_spans - 1,
+                }

     def _forward(self, model_inputs):
         p_mask = model_inputs.pop("p_mask", None)
         word_ids = model_inputs.pop("word_ids", None)
         words = model_inputs.pop("words", None)
+        is_last = model_inputs.pop("is_last", False)
+
+        if "overflow_to_sample_mapping" in model_inputs:
+            model_inputs.pop("overflow_to_sample_mapping")

         if self.model_type == ModelType.VisionEncoderDecoder:
             model_outputs = self.model.generate(**model_inputs)
         else:
             model_outputs = self.model(**model_inputs)

+        model_outputs = {k: v for (k, v) in model_outputs.items()}
         model_outputs["p_mask"] = p_mask
         model_outputs["word_ids"] = word_ids
         model_outputs["words"] = words
         model_outputs["attention_mask"] = model_inputs.get("attention_mask", None)
+        model_outputs["is_last"] = is_last
         return model_outputs

     def postprocess(self, model_outputs, top_k=1, **kwargs):
         if self.model_type == ModelType.VisionEncoderDecoder:
-            answers = self.postprocess_donut(model_outputs)
+            answers = [self.postprocess_encoder_decoder_single(o) for o in model_outputs]
         else:
             answers = self.postprocess_extractive_qa(model_outputs, top_k=top_k, **kwargs)

         answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k]
         return answers

-    def postprocess_donut(self, model_outputs, **kwargs):
+    def postprocess_encoder_decoder_single(self, model_outputs, **kwargs):
         sequence = self.tokenizer.batch_decode(model_outputs.sequences)[0]

         # TODO: A lot of this logic is specific to Donut and should probably be handled in the tokenizer
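With the move to `ChunkPipeline` (imported from `pipelines/base.py` at the top of this file), `preprocess()` yields one dict per span, `_forward()` runs once per span, and `postprocess()` receives the list of all per-span outputs. A simplified view of that contract; the real control flow in the library is more involved:

```python
# Simplified sketch of the ChunkPipeline contract this class now relies on.
def run_single(inputs, preprocess, forward, postprocess):
    model_outputs = [forward(item) for item in preprocess(inputs)]  # one call per span
    return postprocess(model_outputs)                               # ranks answers across spans
```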
@@ -400,41 +434,40 @@ class DocumentQuestionAnsweringPipeline(Pipeline):
         answer = re.search(r"<s_answer>(.*)</s_answer>", sequence)
         if answer is not None:
             ret["answer"] = answer.group(1).strip()
-        return [ret]
+        return ret

     def postprocess_extractive_qa(
         self, model_outputs, top_k=1, handle_impossible_answer=False, max_answer_len=15, **kwargs
     ):
         min_null_score = 1000000  # large and positive
         answers = []
-        words = model_outputs["words"]
-
-        # TODO: Currently, we expect the length of model_outputs to be 1, because we do not stride
-        # in the preprocessor code. When we implement that, we'll either need to handle tensors of size
-        # > 1 or use the ChunkPipeline and handle multiple outputs (each of size = 1).
-        starts, ends, scores, min_null_score = select_starts_ends(
-            model_outputs["start_logits"],
-            model_outputs["end_logits"],
-            model_outputs["p_mask"],
-            model_outputs["attention_mask"].numpy() if model_outputs.get("attention_mask", None) is not None else None,
-            min_null_score,
-            top_k,
-            handle_impossible_answer,
-            max_answer_len,
-        )
-
-        word_ids = model_outputs["word_ids"][0]
-        for start, eend, score in zip(starts, ends, scores):
-            word_start, word_end = word_ids[start], word_ids[eend]
-            if word_start is not None and word_end is not None:
-                answers.append(
-                    {
-                        "score": float(score),  # XXX Write a test that verifies the result is JSON-serializable
-                        "answer": " ".join(words[word_start : word_end + 1]),
-                        "start": word_start,
-                        "end": word_end,
-                    }
-                )
+        for output in model_outputs:
+            words = output["words"]
+
+            starts, ends, scores, min_null_score = select_starts_ends(
+                start=output["start_logits"],
+                end=output["end_logits"],
+                p_mask=output["p_mask"],
+                attention_mask=output["attention_mask"].numpy()
+                if output.get("attention_mask", None) is not None
+                else None,
+                min_null_score=min_null_score,
+                top_k=top_k,
+                handle_impossible_answer=handle_impossible_answer,
+                max_answer_len=max_answer_len,
+            )
+            word_ids = output["word_ids"]
+            for start, end, score in zip(starts, ends, scores):
+                word_start, word_end = word_ids[start], word_ids[end]
+                if word_start is not None and word_end is not None:
+                    answers.append(
+                        {
+                            "score": float(score),
+                            "answer": " ".join(words[word_start : word_end + 1]),
+                            "start": word_start,
+                            "end": word_end,
+                        }
+                    )

         if handle_impossible_answer:
             answers.append({"score": min_null_score, "answer": "", "start": 0, "end": 0})
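Because `min_null_score` is threaded through the per-span loop, the "no answer" entry appended by `handle_impossible_answer` reflects the best null score seen across all spans, not just the last one. Numerically, with made-up per-span values:

```python
# Illustrative: the running minimum survives the loop over spans.
min_null_score = 1000000  # large and positive
for span_null_score in [12.3, 7.9, 9.4]:
    min_null_score = min(min_null_score, span_null_score)
print(min_null_score)  # 7.9 -- the most confident "unanswerable" score over all spans
```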

tests/pipelines/test_pipelines_document_question_answering.py

@@ -191,6 +191,52 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
             * 2,
         )

+    @slow
+    @require_torch
+    @require_detectron2
+    @require_pytesseract
+    def test_large_model_pt_chunk(self):
+        dqa_pipeline = pipeline(
+            "document-question-answering",
+            model="tiennvcs/layoutlmv2-base-uncased-finetuned-docvqa",
+            revision="9977165",
+            max_seq_len=50,
+        )
+        image = INVOICE_URL
+        question = "What is the invoice number?"
+
+        outputs = dqa_pipeline(image=image, question=question, top_k=2)
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
+                {"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
+            ],
+        )
+
+        outputs = dqa_pipeline({"image": image, "question": question}, top_k=2)
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
+                {"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
+            ],
+        )
+
+        outputs = dqa_pipeline(
+            [{"image": image, "question": question}, {"image": image, "question": question}], top_k=2
+        )
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                [
+                    {"score": 0.9967, "answer": "1102/2019", "start": 22, "end": 22},
+                    {"score": 0.996, "answer": "us-001", "start": 15, "end": 15},
+                ]
+            ]
+            * 2,
+        )
+
     @slow
     @require_torch
     @require_pytesseract
@@ -252,6 +298,59 @@ class DocumentQuestionAnsweringPipelineTests(unittest.TestCase, metaclass=Pipeli
             ],
         )

+    @slow
+    @require_torch
+    @require_pytesseract
+    @require_vision
+    def test_large_model_pt_layoutlm_chunk(self):
+        tokenizer = AutoTokenizer.from_pretrained(
+            "impira/layoutlm-document-qa", revision="3dc6de3", add_prefix_space=True
+        )
+        dqa_pipeline = pipeline(
+            "document-question-answering",
+            model="impira/layoutlm-document-qa",
+            tokenizer=tokenizer,
+            revision="3dc6de3",
+            max_seq_len=50,
+        )
+        image = INVOICE_URL
+        question = "What is the invoice number?"
+
+        outputs = dqa_pipeline(image=image, question=question, top_k=2)
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
+                {"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
+            ],
+        )
+
+        outputs = dqa_pipeline(
+            [{"image": image, "question": question}, {"image": image, "question": question}], top_k=2
+        )
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                [
+                    {"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
+                    {"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
+                ]
+            ]
+            * 2,
+        )
+
+        word_boxes = list(zip(*apply_tesseract(load_image(image), None, "")))
+
+        # This model should also work if `image` is set to None
+        outputs = dqa_pipeline({"image": None, "word_boxes": word_boxes, "question": question}, top_k=2)
+        self.assertEqual(
+            nested_simplify(outputs, decimals=4),
+            [
+                {"score": 0.9999, "answer": "us-001", "start": 15, "end": 15},
+                {"score": 0.9924, "answer": "us-001", "start": 15, "end": 15},
+            ],
+        )
+
     @slow
     @require_torch
     def test_large_model_pt_donut(self):
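The last assertion in the new LayoutLM test exercises the OCR-free path: callers can supply `word_boxes` directly and set `image` to `None`. A hand-built sketch of that input shape, reusing the `dqa_pipeline` from the test above; the words and coordinates are invented, on the 0-1000 normalized grid that `apply_tesseract()` would produce:

```python
# Hypothetical word/box pairs standing in for real OCR output.
word_boxes = [
    ("invoice", [52, 40, 190, 62]),
    ("number:", [200, 40, 330, 62]),
    ("us-001", [340, 40, 455, 62]),
]
outputs = dqa_pipeline(
    {"image": None, "word_boxes": word_boxes, "question": "What is the invoice number?"},
    top_k=1,
)
```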