Enabling `Tapex` in table question answering pipeline. (#16663)
* Enabling `Tapex` in table question answering pipeline. * Questions are independant for Tapex, making the test respect that. * Missing extra space.
This commit is contained in:
parent
442dc45645
commit
195fbbb6cf
|
@ -105,9 +105,10 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||
else MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING
|
||||
)
|
||||
|
||||
self.aggregate = bool(getattr(self.model.config, "aggregation_labels")) and bool(
|
||||
getattr(self.model.config, "num_aggregation_labels")
|
||||
self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool(
|
||||
getattr(self.model.config, "num_aggregation_labels", None)
|
||||
)
|
||||
self.type = "tapas" if hasattr(self.model.config, "aggregation_labels") else None
|
||||
|
||||
def batch_inference(self, **inputs):
|
||||
return self.model(**inputs)
|
||||
|
@ -335,7 +336,13 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||
forward_params["sequential"] = sequential
|
||||
return preprocess_params, forward_params, {}
|
||||
|
||||
def preprocess(self, pipeline_input, sequential=None, padding=True, truncation="drop_rows_to_fit"):
|
||||
def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):
|
||||
if truncation is None:
|
||||
if self.type == "tapas":
|
||||
truncation = "drop_rows_to_fit"
|
||||
else:
|
||||
truncation = "do_not_truncate"
|
||||
|
||||
table, query = pipeline_input["table"], pipeline_input["query"]
|
||||
if table.empty:
|
||||
raise ValueError("table is empty")
|
||||
|
@ -347,7 +354,14 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||
|
||||
def _forward(self, model_inputs, sequential=False):
|
||||
table = model_inputs.pop("table")
|
||||
outputs = self.sequential_inference(**model_inputs) if sequential else self.batch_inference(**model_inputs)
|
||||
|
||||
if self.type == "tapas":
|
||||
if sequential:
|
||||
outputs = self.sequential_inference(**model_inputs)
|
||||
else:
|
||||
outputs = self.batch_inference(**model_inputs)
|
||||
else:
|
||||
outputs = self.model.generate(**model_inputs)
|
||||
model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
|
||||
return model_outputs
|
||||
|
||||
|
@ -355,37 +369,40 @@ class TableQuestionAnsweringPipeline(Pipeline):
|
|||
inputs = model_outputs["model_inputs"]
|
||||
table = model_outputs["table"]
|
||||
outputs = model_outputs["outputs"]
|
||||
if self.aggregate:
|
||||
logits, logits_agg = outputs[:2]
|
||||
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg)
|
||||
answer_coordinates_batch, agg_predictions = predictions
|
||||
aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}
|
||||
if self.type == "tapas":
|
||||
if self.aggregate:
|
||||
logits, logits_agg = outputs[:2]
|
||||
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg)
|
||||
answer_coordinates_batch, agg_predictions = predictions
|
||||
aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}
|
||||
|
||||
no_agg_label_index = self.model.config.no_aggregation_label_index
|
||||
aggregators_prefix = {
|
||||
i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
|
||||
}
|
||||
no_agg_label_index = self.model.config.no_aggregation_label_index
|
||||
aggregators_prefix = {
|
||||
i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
|
||||
}
|
||||
else:
|
||||
logits = outputs[0]
|
||||
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits)
|
||||
answer_coordinates_batch = predictions[0]
|
||||
aggregators = {}
|
||||
aggregators_prefix = {}
|
||||
answers = []
|
||||
for index, coordinates in enumerate(answer_coordinates_batch):
|
||||
cells = [table.iat[coordinate] for coordinate in coordinates]
|
||||
aggregator = aggregators.get(index, "")
|
||||
aggregator_prefix = aggregators_prefix.get(index, "")
|
||||
answer = {
|
||||
"answer": aggregator_prefix + ", ".join(cells),
|
||||
"coordinates": coordinates,
|
||||
"cells": [table.iat[coordinate] for coordinate in coordinates],
|
||||
}
|
||||
if aggregator:
|
||||
answer["aggregator"] = aggregator
|
||||
|
||||
answers.append(answer)
|
||||
if len(answer) == 0:
|
||||
raise PipelineException("Empty answer")
|
||||
else:
|
||||
logits = outputs[0]
|
||||
predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits)
|
||||
answer_coordinates_batch = predictions[0]
|
||||
aggregators = {}
|
||||
aggregators_prefix = {}
|
||||
answers = [{"answer": answer} for answer in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)]
|
||||
|
||||
answers = []
|
||||
for index, coordinates in enumerate(answer_coordinates_batch):
|
||||
cells = [table.iat[coordinate] for coordinate in coordinates]
|
||||
aggregator = aggregators.get(index, "")
|
||||
aggregator_prefix = aggregators_prefix.get(index, "")
|
||||
answer = {
|
||||
"answer": aggregator_prefix + ", ".join(cells),
|
||||
"coordinates": coordinates,
|
||||
"cells": [table.iat[coordinate] for coordinate in coordinates],
|
||||
}
|
||||
if aggregator:
|
||||
answer["aggregator"] = aggregator
|
||||
|
||||
answers.append(answer)
|
||||
if len(answer) == 0:
|
||||
raise PipelineException("Empty answer")
|
||||
return answers if len(answers) > 1 else answers[0]
|
||||
|
|
|
@ -632,3 +632,31 @@ class TQAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
|||
{"answer": "28 november 1967", "coordinates": [(2, 3)], "cells": ["28 november 1967"]},
|
||||
]
|
||||
self.assertListEqual(results, expected_results)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_large_model_pt_tapex(self):
|
||||
model_id = "microsoft/tapex-large-finetuned-wtq"
|
||||
table_querier = pipeline(
|
||||
"table-question-answering",
|
||||
model=model_id,
|
||||
)
|
||||
data = {
|
||||
"Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
|
||||
"Age": ["56", "45", "59"],
|
||||
"Number of movies": ["87", "53", "69"],
|
||||
"Date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
|
||||
}
|
||||
queries = [
|
||||
"How many movies has George Clooney played in?",
|
||||
"How old is Mr Clooney ?",
|
||||
"What's the date of birth of Leonardo ?",
|
||||
]
|
||||
results = table_querier(data, queries, sequential=True)
|
||||
|
||||
expected_results = [
|
||||
{"answer": " 69"},
|
||||
{"answer": " 59"},
|
||||
{"answer": " 10 june 1996"},
|
||||
]
|
||||
self.assertListEqual(results, expected_results)
|
||||
|
|
Loading…
Reference in New Issue