TableQuestionAnsweringPipeline (#9145)

* AutoModelForTableQuestionAnswering

* TableQuestionAnsweringPipeline

* Apply suggestions from Patrick's code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Sylvain and Patrick comments

* Better PyTorch/TF error message

* Add integration tests

* Argument Handler naming

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>

* Fix docs to appease the documentation gods

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Lysandre Debut 2020-12-16 12:31:50 -05:00 committed by GitHub
parent 07384baf7a
commit 1c1a2ffbff
8 changed files with 602 additions and 27 deletions

----------------------------------------------------------------------

@@ -34,6 +34,7 @@ There are two categories of pipeline abstractions to be aware about:
- :class:`~transformers.TranslationPipeline`
- :class:`~transformers.ZeroShotClassificationPipeline`
- :class:`~transformers.Text2TextGenerationPipeline`
- :class:`~transformers.TableQuestionAnsweringPipeline`
The pipeline abstraction
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

@@ -91,6 +92,13 @@ SummarizationPipeline
    :special-members: __call__
    :members:

TableQuestionAnsweringPipeline
=======================================================================================================================

.. autoclass:: transformers.TableQuestionAnsweringPipeline
    :special-members: __call__

TextClassificationPipeline
=======================================================================================================================

----------------------------------------------------------------------

@@ -190,6 +190,7 @@ from .pipelines import (
    PipelineDataFormat,
    QuestionAnsweringPipeline,
    SummarizationPipeline,
    TableQuestionAnsweringPipeline,
    Text2TextGenerationPipeline,
    TextClassificationPipeline,
    TextGenerationPipeline,

----------------------------------------------------------------------

@@ -468,6 +468,13 @@ explained here: https://github.com/rusty1s/pytorch_scatter.
"""

# docstyle-ignore
PANDAS_IMPORT_ERROR = """
{0} requires the pandas library but it was not found in your environment. You can install it with pip as
explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
"""


def requires_datasets(obj):
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    if not is_datasets_available():

@@ -522,6 +529,12 @@ def requires_protobuf(obj):
        raise ImportError(PROTOBUF_IMPORT_ERROR.format(name))


def requires_pandas(obj):
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    if not is_pandas_available():
        raise ImportError(PANDAS_IMPORT_ERROR.format(name))


def requires_scatter(obj):
    name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
    if not is_scatter_available():
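
The new helper follows the existing requires_* pattern: call it at the top of any entry point that needs pandas, then import inside the function. A minimal sketch of a consumer, assuming only the names added above (MyTableTool is hypothetical):

    class MyTableTool:
        def __init__(self):
            requires_pandas(self)  # raises ImportError with PANDAS_IMPORT_ERROR if pandas is absent
            import pandas as pd  # safe to import only after the availability check

            self.pd = pd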

----------------------------------------------------------------------

@@ -87,41 +87,49 @@ class TapasConfig(PretrainedConfig):
            Importance weight for the regression loss.
        use_normalized_answer_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to normalize the answer loss by the maximum of the predicted and expected value.
        huber_loss_delta (:obj:`float`, `optional`):
            Delta parameter used to calculate the regression loss.
        temperature (:obj:`float`, `optional`, defaults to 1.0):
            Value used to control (OR change) the skewness of cell logits probabilities.
        aggregation_temperature (:obj:`float`, `optional`, defaults to 1.0):
            Scales aggregation logits to control the skewness of probabilities.
        use_gumbel_for_cells (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to apply Gumbel-Softmax to cell selection.
        use_gumbel_for_aggregation (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to apply Gumbel-Softmax to aggregation selection.
        average_approximation_function (:obj:`string`, `optional`, defaults to :obj:`"ratio"`):
            Method to calculate the expected average of cells in the weak supervision case. One of :obj:`"ratio"`,
            :obj:`"first_order"` or :obj:`"second_order"`.
        cell_selection_preference (:obj:`float`, `optional`):
            Preference for cell selection in ambiguous cases. Only applicable in case of weak supervision for
            aggregation (WTQ, WikiSQL). If the total mass of the aggregation probabilities (excluding the "NONE"
            operator) is higher than this hyperparameter, then aggregation is predicted for an example.
        answer_loss_cutoff (:obj:`float`, `optional`):
            Ignore examples with answer loss larger than cutoff.
        max_num_rows (:obj:`int`, `optional`, defaults to 64):
            Maximum number of rows.
        max_num_columns (:obj:`int`, `optional`, defaults to 32):
            Maximum number of columns.
        average_logits_per_cell (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to average logits per cell.
        select_one_column (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to constrain the model to only select cells from a single column.
        allow_empty_column_selection (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to allow not to select any column.
        init_cell_selection_weights_to_zero (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to initialize cell selection weights to 0 so that the initial probabilities are 50%.
        reset_position_index_per_cell (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether to restart position indexes at every cell (i.e. use relative position embeddings).
        disable_per_token_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to disable any (strong or weak) supervision on cells.
        aggregation_labels (:obj:`Dict[int, label]`, `optional`):
            The aggregation labels used to aggregate the results. For example, the WTQ models have the following
            aggregation labels: :obj:`{0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}`
        no_aggregation_label_index (:obj:`int`, `optional`):
            If the aggregation labels are defined and one of these labels represents "No aggregation", this should
            be set to its index. For example, the WTQ models have the "NONE" aggregation label at index 0, so that
            value should be set to 0 for these models.

    Example::

@@ -174,6 +182,8 @@ class TapasConfig(PretrainedConfig):
        init_cell_selection_weights_to_zero=False,
        reset_position_index_per_cell=True,
        disable_per_token_loss=False,
        aggregation_labels=None,
        no_aggregation_label_index=None,
        **kwargs
    ):

@@ -217,3 +227,10 @@ class TapasConfig(PretrainedConfig):
        self.init_cell_selection_weights_to_zero = init_cell_selection_weights_to_zero
        self.reset_position_index_per_cell = reset_position_index_per_cell
        self.disable_per_token_loss = disable_per_token_loss

        # Aggregation hyperparameters
        self.aggregation_labels = aggregation_labels
        self.no_aggregation_label_index = no_aggregation_label_index

        if isinstance(self.aggregation_labels, dict):
            self.aggregation_labels = {int(k): v for k, v in aggregation_labels.items()}
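
A minimal sketch of how the two new fields fit together, using the WTQ values from the docstring above (the string keys are illustrative, e.g. as loaded from a JSON config; the constructor casts them to int):

    from transformers import TapasConfig

    # WTQ-style aggregation head: "NONE" sits at index 0.
    config = TapasConfig(
        aggregation_labels={"0": "NONE", "1": "SUM", "2": "AVERAGE", "3": "COUNT"},
        no_aggregation_label_index=0,
    )
    assert config.aggregation_labels[0] == "NONE"  # keys were cast to int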

----------------------------------------------------------------------

@@ -1905,12 +1905,14 @@ class TapasTokenizer(PreTrainedTokenizer):
                this threshold will be selected.

        Returns:
            :obj:`tuple` comprising various elements depending on the inputs:

            - predicted_answer_coordinates (``List[List[tuple]]`` of length ``batch_size``): Predicted answer
              coordinates as a list of lists of tuples. Each element in the list contains the predicted answer
              coordinates of a single example in the batch, as a list of tuples. Each tuple is a cell, i.e. (row
              index, column index).
            - predicted_aggregation_indices (``List[int]`` of length ``batch_size``, `optional`, returned when
              ``logits_aggregation`` is provided): Predicted aggregation operator indices of the aggregation head.
        """
        # input data is of type float32
        # np.log(np.finfo(np.float32).max) = 88.72284

@@ -1969,11 +1971,11 @@ class TapasTokenizer(PreTrainedTokenizer):
            answer_coordinates = sorted(answer_coordinates)
            predicted_answer_coordinates.append(answer_coordinates)

        output = (predicted_answer_coordinates,)

        if logits_agg is not None:
            predicted_aggregation_indices = logits_agg.argmax(dim=-1)
            output = (predicted_answer_coordinates, predicted_aggregation_indices.tolist())

        return output
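
With this change the method always returns a tuple, so callers can unpack positionally whether or not aggregation logits were passed. A minimal sketch, assuming a TAPAS tokenizer plus `inputs`/`outputs` from a forward pass:

    # Without aggregation logits: a 1-tuple.
    (predicted_answer_coordinates,) = tokenizer.convert_logits_to_predictions(inputs, outputs.logits.detach())

    # With aggregation logits: a 2-tuple.
    predicted_answer_coordinates, predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs, outputs.logits.detach(), outputs.logits_aggregation
    )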

----------------------------------------------------------------------

@@ -12,8 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import csv
import json
import os

@@ -32,7 +31,7 @@ import numpy as np
from .configuration_utils import PretrainedConfig
from .data import SquadExample, SquadFeatures, squad_convert_examples_to_features
from .file_utils import add_end_docstrings, is_tf_available, is_torch_available, requires_pandas
from .modelcard import ModelCard
from .models.auto.configuration_auto import AutoConfig
from .models.auto.tokenization_auto import AutoTokenizer

@@ -68,6 +67,7 @@ if is_torch_available():
        MODEL_FOR_QUESTION_ANSWERING_MAPPING,
        MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
        MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
        MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
        MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
        AutoModel,
        AutoModelForCausalLM,

@@ -75,6 +75,7 @@ if is_torch_available():
        AutoModelForQuestionAnswering,
        AutoModelForSeq2SeqLM,
        AutoModelForSequenceClassification,
        AutoModelForTableQuestionAnswering,
        AutoModelForTokenClassification,
    )

@@ -2058,6 +2059,274 @@ class QuestionAnsweringPipeline(Pipeline):
        }

class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
    """
    Handles arguments for the TableQuestionAnsweringPipeline
    """

    def __call__(self, table=None, query=None, sequential=False, padding=True, truncation=True):
        # Returns tqa_pipeline_inputs of shape:
        # [
        #   {"table": pd.DataFrame, "query": List[str]},
        #   ...,
        #   {"table": pd.DataFrame, "query": List[str]}
        # ]
        requires_pandas(self)
        import pandas as pd

        if table is None:
            raise ValueError("Keyword argument `table` cannot be None.")
        elif query is None:
            if isinstance(table, dict) and table.get("query") is not None and table.get("table") is not None:
                tqa_pipeline_inputs = [table]
            elif isinstance(table, list) and len(table) > 0:
                if not all(isinstance(d, dict) for d in table):
                    raise ValueError(
                        f"Keyword argument `table` should be a list of dict, but is {[type(d) for d in table]}"
                    )

                if table[0].get("query") is not None and table[0].get("table") is not None:
                    tqa_pipeline_inputs = table
                else:
                    raise ValueError(
                        f"If keyword argument `table` is a list of dictionaries, each dictionary should have "
                        f"`table` and `query` keys, but the first dictionary has keys {table[0].keys()}."
                    )
            else:
                raise ValueError(
                    f"Invalid input. Keyword argument `table` should be either of type `dict` or `list`, but "
                    f"is {type(table)}."
                )
        else:
            tqa_pipeline_inputs = [{"table": table, "query": query}]

        for tqa_pipeline_input in tqa_pipeline_inputs:
            if not isinstance(tqa_pipeline_input["table"], pd.DataFrame):
                if tqa_pipeline_input["table"] is None:
                    raise ValueError("Table cannot be None.")

                tqa_pipeline_input["table"] = pd.DataFrame(tqa_pipeline_input["table"])

        return tqa_pipeline_inputs, sequential, padding, truncation
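
All of the accepted call shapes normalize to the same list of {"table": DataFrame, "query": ...} dicts. A minimal sketch with illustrative data:

    handler = TableQuestionAnsweringArgumentHandler()
    data = {"actors": ["brad pitt", "george clooney"], "age": ["56", "59"]}

    # Each of these collapses to [{"table": pd.DataFrame, "query": ...}]:
    inputs, sequential, padding, truncation = handler(table=data, query="how old is brad pitt?")
    inputs, _, _, _ = handler({"table": data, "query": "how old is brad pitt?"})
    inputs, _, _, _ = handler([{"table": data, "query": "how old is brad pitt?"}])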

@add_end_docstrings(PIPELINE_INIT_ARGS)
class TableQuestionAnsweringPipeline(Pipeline):
    """
    Table Question Answering pipeline using a :obj:`ModelForTableQuestionAnswering`. This pipeline is only
    available in PyTorch.

    This tabular question answering pipeline can currently be loaded from :func:`~transformers.pipeline` using the
    following task identifier: :obj:`"table-question-answering"`.

    The models that this pipeline can use are models that have been fine-tuned on a tabular question answering
    task. See the up-to-date list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=table-question-answering>`__.
    """

    default_input_names = "table,query"

    def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._args_parser = args_parser

        if self.framework == "tf":
            raise ValueError("The TableQuestionAnsweringPipeline is only available in PyTorch.")

        self.check_model_type(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING)

        self.aggregate = bool(getattr(self.model.config, "aggregation_labels")) and bool(
            getattr(self.model.config, "num_aggregation_labels")
        )

    def batch_inference(self, **inputs):
        with torch.no_grad():
            return self.model(**inputs)

    def sequential_inference(self, **inputs):
        """
        Inference used for models that need to process sequences in a sequential fashion, like the SQA models
        which handle conversational query related to a table.
        """
        with torch.no_grad():
            all_logits = []
            all_aggregations = []
            prev_answers = None
            batch_size = inputs["input_ids"].shape[0]

            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)
            token_type_ids = inputs["token_type_ids"].to(self.device)
            token_type_ids_example = None

            for index in range(batch_size):
                # If sequences have already been processed, the token type IDs will be created according to the
                # previous answer.
                if prev_answers is not None:
                    prev_labels_example = token_type_ids_example[:, 3]  # shape (seq_len,)
                    model_labels = np.zeros_like(prev_labels_example.cpu().numpy())  # shape (seq_len,)

                    token_type_ids_example = token_type_ids[index]  # shape (seq_len, 7)
                    for i in range(model_labels.shape[0]):
                        segment_id = token_type_ids_example[:, 0].tolist()[i]
                        col_id = token_type_ids_example[:, 1].tolist()[i] - 1
                        row_id = token_type_ids_example[:, 2].tolist()[i] - 1

                        if row_id >= 0 and col_id >= 0 and segment_id == 1:
                            model_labels[i] = int(prev_answers[(col_id, row_id)])

                    token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device)

                input_ids_example = input_ids[index]
                attention_mask_example = attention_mask[index]  # shape (seq_len,)
                token_type_ids_example = token_type_ids[index]  # shape (seq_len, 7)
                outputs = self.model(
                    input_ids=input_ids_example.unsqueeze(0),
                    attention_mask=attention_mask_example.unsqueeze(0),
                    token_type_ids=token_type_ids_example.unsqueeze(0),
                )
                logits = outputs.logits

                if self.aggregate:
                    all_aggregations.append(outputs.logits_aggregation)

                all_logits.append(logits)

                dist_per_token = torch.distributions.Bernoulli(logits=logits)
                probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to(
                    dist_per_token.probs.device
                )

                coords_to_probs = collections.defaultdict(list)
                for i, p in enumerate(probabilities.squeeze().tolist()):
                    segment_id = token_type_ids_example[:, 0].tolist()[i]
                    col = token_type_ids_example[:, 1].tolist()[i] - 1
                    row = token_type_ids_example[:, 2].tolist()[i] - 1
                    if col >= 0 and row >= 0 and segment_id == 1:
                        coords_to_probs[(col, row)].append(p)

                prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs}

            logits_batch = torch.cat(tuple(all_logits), 0)

            return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0))

    def __call__(self, *args, **kwargs):
        r"""
        Answers queries according to a table. The pipeline accepts several types of inputs which are detailed
        below:

        - ``pipeline(table, query)``
        - ``pipeline(table, [query])``
        - ``pipeline(table=table, query=query)``
        - ``pipeline(table=table, query=[query])``
        - ``pipeline({"table": table, "query": query})``
        - ``pipeline({"table": table, "query": [query]})``
        - ``pipeline([{"table": table, "query": query}, {"table": table, "query": query}])``

        The :obj:`table` argument should be a dict or a DataFrame built from that dict, containing the whole
        table:

        Example::

            data = {
                "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
                "age": ["56", "45", "59"],
                "number of movies": ["87", "53", "69"],
                "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
            }

        This dictionary can be passed in as such, or can be converted to a pandas DataFrame:

        Example::

            import pandas as pd

            table = pd.DataFrame.from_dict(data)

        Args:
            table (:obj:`pd.DataFrame` or :obj:`Dict`):
                Pandas DataFrame or dictionary that will be converted to a DataFrame containing all the table
                values. See above for an example of dictionary.
            query (:obj:`str` or :obj:`List[str]`):
                Query or list of queries that will be sent to the model alongside the table.
            sequential (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether to do inference sequentially or as a batch. Batching is faster, but models like SQA require
                the inference to be done sequentially to extract relations within sequences, given their
                conversational nature.
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
                Activates and controls padding. Accepts the following values:

                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only
                  a single sequence is provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to
                  the maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with
                  sequences of different lengths).
            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.TapasTruncationStrategy`, `optional`, defaults to :obj:`False`):
                Activates and controls truncation. Accepts the following values:

                * :obj:`True` or :obj:`'drop_rows_to_fit'`: Truncate to a maximum length specified with the
                  argument :obj:`max_length` or to the maximum acceptable input length for the model if that
                  argument is not provided. This will truncate row by row, removing rows from the table.
                * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
                  sequence lengths greater than the model maximum admissible input size).

        Return:
            A dictionary or a list of dictionaries containing results. Each result is a dictionary with the
            following keys:

            - **answer** (:obj:`str`) -- The answer of the query given the table. If there is an aggregator, the
              answer will be preceded by :obj:`AGGREGATOR >`.
            - **coordinates** (:obj:`List[Tuple[int, int]]`) -- Coordinates of the cells of the answers.
            - **cells** (:obj:`List[str]`) -- List of strings made up of the answer cell values.
            - **aggregator** (:obj:`str`) -- If the model has an aggregator, this returns the aggregator.
        """
        pipeline_inputs, sequential, padding, truncation = self._args_parser(*args, **kwargs)

        batched_answers = []
        for pipeline_input in pipeline_inputs:
            table, query = pipeline_input["table"], pipeline_input["query"]
            inputs = self.tokenizer(
                table, query, return_tensors=self.framework, truncation="drop_rows_to_fit", padding=padding
            )

            outputs = self.sequential_inference(**inputs) if sequential else self.batch_inference(**inputs)

            if self.aggregate:
                logits, logits_agg = outputs[:2]
                predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits.detach(), logits_agg)
                answer_coordinates_batch, agg_predictions = predictions
                aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}

                no_agg_label_index = self.model.config.no_aggregation_label_index
                aggregators_prefix = {
                    i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
                }
            else:
                logits = outputs[0]
                predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits.detach())
                answer_coordinates_batch = predictions[0]
                aggregators = {}
                aggregators_prefix = {}

            answers = []
            for index, coordinates in enumerate(answer_coordinates_batch):
                cells = [table.iat[coordinate] for coordinate in coordinates]
                aggregator = aggregators.get(index, "")
                aggregator_prefix = aggregators_prefix.get(index, "")
                answer = {
                    "answer": aggregator_prefix + ", ".join(cells),
                    "coordinates": coordinates,
                    "cells": cells,
                }
                if aggregator:
                    answer["aggregator"] = aggregator

                answers.append(answer)

            batched_answers.append(answers if len(answers) > 1 else answers[0])

        return batched_answers if len(batched_answers) > 1 else batched_answers[0]
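
End to end, the pipeline is used as below; a minimal sketch based on the default checkpoint registered further down and on the integration tests (exact answers depend on the checkpoint):

    from transformers import pipeline

    tqa = pipeline("table-question-answering")  # resolves to the WTQ default below

    table = {
        "Repository": ["Transformers", "Datasets", "Tokenizers"],
        "Stars": ["36542", "4512", "3934"],
    }
    result = tqa(table=table, query="What repository has the largest number of stars?")
    print(result["answer"])  # "Transformers" in the WTQ integration test below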

@add_end_docstrings(PIPELINE_INIT_ARGS)
class SummarizationPipeline(Pipeline):
    """
@@ -2752,6 +3021,18 @@ SUPPORTED_TASKS = {
            "model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
        },
    },
    "table-question-answering": {
        "impl": TableQuestionAnsweringPipeline,
        "pt": AutoModelForTableQuestionAnswering if is_torch_available() else None,
        "tf": None,
        "default": {
            "model": {
                "pt": "nielsr/tapas-base-finetuned-wtq",
                "tokenizer": "nielsr/tapas-base-finetuned-wtq",
                "tf": "nielsr/tapas-base-finetuned-wtq",
            },
        },
    },
    "fill-mask": {
        "impl": FillMaskPipeline,
        "tf": TFAutoModelForMaskedLM if is_tf_available() else None,

@@ -3006,6 +3287,12 @@ def pipeline(
                "Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
                "Trying to load the model with Tensorflow."
            )

        if model_class is None:
            raise ValueError(
                f"Pipeline using {framework} framework, but this framework is not supported by this pipeline."
            )

        model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs)

    if task == "translation" and model.config.task_specific_params:
        for key in model.config.task_specific_params:

----------------------------------------------------------------------

@@ -172,6 +172,19 @@ def require_torch(test_case):
    return test_case


def require_torch_scatter(test_case):
    """
    Decorator marking a test that requires PyTorch scatter.

    These tests are skipped when PyTorch scatter isn't installed.
    """
    if not _scatter_available:
        return unittest.skip("test requires PyTorch scatter")(test_case)
    else:
        return test_case


def require_tf(test_case):
    """
    Decorator marking a test that requires TensorFlow.

----------------------------------------------------------------------

@@ -0,0 +1,234 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from transformers.pipelines import Pipeline, pipeline
from transformers.testing_utils import require_pandas, require_torch, require_torch_scatter, slow

from .test_pipelines_common import CustomInputPipelineCommonMixin


@require_torch_scatter
@require_torch
@require_pandas
class TQAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
    pipeline_task = "table-question-answering"
    pipeline_running_kwargs = {
        "padding": "max_length",
    }
    small_models = [
        "lysandre/tiny-tapas-random-wtq",
        "lysandre/tiny-tapas-random-sqa",
    ]
    large_models = ["nielsr/tapas-base-finetuned-wtq"]  # Models tested with the @slow decorator
    valid_inputs = [
        {
            "table": {
                "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
                "age": ["56", "45", "59"],
                "number of movies": ["87", "53", "69"],
                "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
            },
            "query": "how many movies has george clooney played in?",
        },
        {
            "table": {
                "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
                "age": ["56", "45", "59"],
                "number of movies": ["87", "53", "69"],
                "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
            },
            "query": ["how many movies has george clooney played in?", "how old is he?", "what's his date of birth?"],
        },
        {
            "table": {
                "Repository": ["Transformers", "Datasets", "Tokenizers"],
                "Stars": ["36542", "4512", "3934"],
                "Contributors": ["651", "77", "34"],
                "Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
            },
            "query": [
                "What repository has the largest number of stars?",
                "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
                "What is the number of repositories?",
                "What is the average number of stars?",
                "What is the total amount of stars?",
            ],
        },
    ]

    def _test_pipeline(self, table_querier: Pipeline):
        output_keys = {"answer", "coordinates", "cells"}
        valid_inputs = self.valid_inputs
        invalid_inputs = [
            {"query": "What does it do with empty context ?", "table": ""},
            {"query": "What does it do with empty context ?", "table": None},
        ]
        self.assertIsNotNone(table_querier)

        mono_result = table_querier(valid_inputs[0])
        self.assertIsInstance(mono_result, dict)

        for key in output_keys:
            self.assertIn(key, mono_result)

        multi_result = table_querier(valid_inputs)
        self.assertIsInstance(multi_result, list)
        for result in multi_result:
            self.assertIsInstance(result, (list, dict))

        for result in multi_result:
            if isinstance(result, list):
                for _result in result:
                    for key in output_keys:
                        self.assertIn(key, _result)
            else:
                for key in output_keys:
                    self.assertIn(key, result)

        for bad_input in invalid_inputs:
            self.assertRaises(ValueError, table_querier, bad_input)
        self.assertRaises(ValueError, table_querier, invalid_inputs)

    def test_aggregation(self):
        table_querier = pipeline(
            "table-question-answering",
            model="lysandre/tiny-tapas-random-wtq",
            tokenizer="lysandre/tiny-tapas-random-wtq",
        )
        self.assertIsInstance(table_querier.model.config.aggregation_labels, dict)
        self.assertIsInstance(table_querier.model.config.no_aggregation_label_index, int)

        mono_result = table_querier(self.valid_inputs[0])
        multi_result = table_querier(self.valid_inputs)

        self.assertIn("aggregator", mono_result)

        for result in multi_result:
            if isinstance(result, list):
                for _result in result:
                    self.assertIn("aggregator", _result)
            else:
                self.assertIn("aggregator", result)

    def test_aggregation_with_sequential(self):
        table_querier = pipeline(
            "table-question-answering",
            model="lysandre/tiny-tapas-random-wtq",
            tokenizer="lysandre/tiny-tapas-random-wtq",
        )
        self.assertIsInstance(table_querier.model.config.aggregation_labels, dict)
        self.assertIsInstance(table_querier.model.config.no_aggregation_label_index, int)

        mono_result = table_querier(self.valid_inputs[0], sequential=True)
        multi_result = table_querier(self.valid_inputs, sequential=True)

        self.assertIn("aggregator", mono_result)

        for result in multi_result:
            if isinstance(result, list):
                for _result in result:
                    self.assertIn("aggregator", _result)
            else:
                self.assertIn("aggregator", result)

    def test_sequential(self):
        table_querier = pipeline(
            "table-question-answering",
            model="lysandre/tiny-tapas-random-sqa",
            tokenizer="lysandre/tiny-tapas-random-sqa",
        )
        sequential_mono_result_0 = table_querier(self.valid_inputs[0], sequential=True)
        sequential_mono_result_1 = table_querier(self.valid_inputs[1], sequential=True)
        sequential_multi_result = table_querier(self.valid_inputs, sequential=True)
        mono_result_0 = table_querier(self.valid_inputs[0])
        mono_result_1 = table_querier(self.valid_inputs[1])
        multi_result = table_querier(self.valid_inputs)

        # First valid input has a single question, so the dicts should be equal
        self.assertDictEqual(sequential_mono_result_0, mono_result_0)

        # Second valid input has several questions, so the questions following the first one should not be equal
        self.assertNotEqual(sequential_mono_result_1, mono_result_1)

        # Assert that we get the same results when passing in several sequences.
        for index, (sequential_multi, multi) in enumerate(zip(sequential_multi_result, multi_result)):
            if index == 0:
                self.assertDictEqual(sequential_multi, multi)
            else:
                self.assertNotEqual(sequential_multi, multi)

    @slow
    def test_integration_wtq(self):
        tqa_pipeline = pipeline("table-question-answering")

        data = {
            "Repository": ["Transformers", "Datasets", "Tokenizers"],
            "Stars": ["36542", "4512", "3934"],
            "Contributors": ["651", "77", "34"],
            "Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
        }
        queries = [
            "What repository has the largest number of stars?",
            "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
            "What is the number of repositories?",
            "What is the average number of stars?",
            "What is the total amount of stars?",
        ]

        results = tqa_pipeline(data, queries)

        expected_results = [
            {"answer": "Transformers", "coordinates": [(0, 0)], "cells": ["Transformers"]},
            {"answer": "Transformers", "coordinates": [(0, 0)], "cells": ["Transformers"]},
            {
                "answer": "Transformers, Datasets, Tokenizers",
                "coordinates": [(0, 0), (1, 0), (2, 0)],
                "cells": ["Transformers", "Datasets", "Tokenizers"],
            },
            {
                "answer": "36542, 4512, 3934",
                "coordinates": [(0, 1), (1, 1), (2, 1)],
                "cells": ["36542", "4512", "3934"],
            },
            {
                "answer": "36542, 4512, 3934",
                "coordinates": [(0, 1), (1, 1), (2, 1)],
                "cells": ["36542", "4512", "3934"],
            },
        ]
        self.assertListEqual(results, expected_results)

    @slow
    def test_integration_sqa(self):
        tqa_pipeline = pipeline(
            "table-question-answering",
            model="nielsr/tapas-base-finetuned-sqa",
            tokenizer="nielsr/tapas-base-finetuned-sqa",
        )
        data = {
            "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
            "Age": ["56", "45", "59"],
            "Number of movies": ["87", "53", "69"],
            "Date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
        }
        queries = ["How many movies has George Clooney played in?", "How old is he?", "What's his date of birth?"]
        results = tqa_pipeline(data, queries, sequential=True)

        expected_results = [
            {"answer": "69", "coordinates": [(2, 2)], "cells": ["69"]},
            {"answer": "59", "coordinates": [(2, 1)], "cells": ["59"]},
            {"answer": "28 november 1967", "coordinates": [(2, 3)], "cells": ["28 november 1967"]},
        ]
        self.assertListEqual(results, expected_results)