TableQuestionAnsweringPipeline (#9145)
* AutoModelForTableQuestionAnswering * TableQuestionAnsweringPipeline * Apply suggestions from Patrick's code review Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Sylvain and Patrick comments * Better PyTorch/TF error message * Add integration tests * Argument Handler naming Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com> * Fix docs to appease the documentation gods Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
07384baf7a
commit
1c1a2ffbff
|
@ -34,6 +34,7 @@ There are two categories of pipeline abstractions to be aware about:
|
||||||
- :class:`~transformers.TranslationPipeline`
|
- :class:`~transformers.TranslationPipeline`
|
||||||
- :class:`~transformers.ZeroShotClassificationPipeline`
|
- :class:`~transformers.ZeroShotClassificationPipeline`
|
||||||
- :class:`~transformers.Text2TextGenerationPipeline`
|
- :class:`~transformers.Text2TextGenerationPipeline`
|
||||||
|
- :class:`~transformers.TableQuestionAnsweringPipeline`
|
||||||
|
|
||||||
The pipeline abstraction
|
The pipeline abstraction
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
@ -91,6 +92,13 @@ SummarizationPipeline
|
||||||
:special-members: __call__
|
:special-members: __call__
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
|
TableQuestionAnsweringPipeline
|
||||||
|
=======================================================================================================================
|
||||||
|
|
||||||
|
.. autoclass:: transformers.TableQuestionAnsweringPipeline
|
||||||
|
:special-members: __call__
|
||||||
|
|
||||||
|
|
||||||
TextClassificationPipeline
|
TextClassificationPipeline
|
||||||
=======================================================================================================================
|
=======================================================================================================================
|
||||||
|
|
||||||
|
|
|
@ -190,6 +190,7 @@ from .pipelines import (
|
||||||
PipelineDataFormat,
|
PipelineDataFormat,
|
||||||
QuestionAnsweringPipeline,
|
QuestionAnsweringPipeline,
|
||||||
SummarizationPipeline,
|
SummarizationPipeline,
|
||||||
|
TableQuestionAnsweringPipeline,
|
||||||
Text2TextGenerationPipeline,
|
Text2TextGenerationPipeline,
|
||||||
TextClassificationPipeline,
|
TextClassificationPipeline,
|
||||||
TextGenerationPipeline,
|
TextGenerationPipeline,
|
||||||
|
|
|
@ -468,6 +468,13 @@ explained here: https://github.com/rusty1s/pytorch_scatter.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# docstyle-ignore
|
||||||
|
PANDAS_IMPORT_ERROR = """
|
||||||
|
{0} requires the pandas library but it was not found in your environment. You can install it with pip as
|
||||||
|
explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def requires_datasets(obj):
|
def requires_datasets(obj):
|
||||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||||
if not is_datasets_available():
|
if not is_datasets_available():
|
||||||
|
@ -522,6 +529,12 @@ def requires_protobuf(obj):
|
||||||
raise ImportError(PROTOBUF_IMPORT_ERROR.format(name))
|
raise ImportError(PROTOBUF_IMPORT_ERROR.format(name))
|
||||||
|
|
||||||
|
|
||||||
|
def requires_pandas(obj):
    """
    Raise an ``ImportError`` mentioning ``obj`` when pandas is not available.

    Args:
        obj: The function, class or instance that needs pandas. Its ``__name__`` is used in the error message,
            falling back to the name of its class when it has none (e.g. for instances).
    """
    name = getattr(obj, "__name__", obj.__class__.__name__)
    if is_pandas_available():
        return
    raise ImportError(PANDAS_IMPORT_ERROR.format(name))
|
||||||
|
|
||||||
|
|
||||||
def requires_scatter(obj):
|
def requires_scatter(obj):
|
||||||
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
|
||||||
if not is_scatter_available():
|
if not is_scatter_available():
|
||||||
|
|
|
@ -87,41 +87,49 @@ class TapasConfig(PretrainedConfig):
|
||||||
Importance weight for the regression loss.
|
Importance weight for the regression loss.
|
||||||
use_normalized_answer_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
use_normalized_answer_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to normalize the answer loss by the maximum of the predicted and expected value.
|
Whether to normalize the answer loss by the maximum of the predicted and expected value.
|
||||||
huber_loss_delta: (:obj:`float`, `optional`):
|
huber_loss_delta (:obj:`float`, `optional`):
|
||||||
Delta parameter used to calculate the regression loss.
|
Delta parameter used to calculate the regression loss.
|
||||||
temperature: (:obj:`float`, `optional`, defaults to 1.0):
|
temperature (:obj:`float`, `optional`, defaults to 1.0):
|
||||||
Value used to control (OR change) the skewness of cell logits probabilities.
|
Value used to control (OR change) the skewness of cell logits probabilities.
|
||||||
aggregation_temperature: (:obj:`float`, `optional`, defaults to 1.0):
|
aggregation_temperature (:obj:`float`, `optional`, defaults to 1.0):
|
||||||
Scales aggregation logits to control the skewness of probabilities.
|
Scales aggregation logits to control the skewness of probabilities.
|
||||||
use_gumbel_for_cells: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
use_gumbel_for_cells (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to apply Gumbel-Softmax to cell selection.
|
Whether to apply Gumbel-Softmax to cell selection.
|
||||||
use_gumbel_for_aggregation: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
use_gumbel_for_aggregation (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to apply Gumbel-Softmax to aggregation selection.
|
Whether to apply Gumbel-Softmax to aggregation selection.
|
||||||
average_approximation_function: (:obj:`string`, `optional`, defaults to :obj:`"ratio"`):
|
average_approximation_function (:obj:`string`, `optional`, defaults to :obj:`"ratio"`):
|
||||||
Method to calculate the expected average of cells in the weak supervision case. One of :obj:`"ratio"`,
|
Method to calculate the expected average of cells in the weak supervision case. One of :obj:`"ratio"`,
|
||||||
:obj:`"first_order"` or :obj:`"second_order"`.
|
:obj:`"first_order"` or :obj:`"second_order"`.
|
||||||
cell_selection_preference: (:obj:`float`, `optional`):
|
cell_selection_preference (:obj:`float`, `optional`):
|
||||||
Preference for cell selection in ambiguous cases. Only applicable in case of weak supervision for
|
Preference for cell selection in ambiguous cases. Only applicable in case of weak supervision for
|
||||||
aggregation (WTQ, WikiSQL). If the total mass of the aggregation probabilities (excluding the "NONE"
|
aggregation (WTQ, WikiSQL). If the total mass of the aggregation probabilities (excluding the "NONE"
|
||||||
operator) is higher than this hyperparameter, then aggregation is predicted for an example.
|
operator) is higher than this hyperparameter, then aggregation is predicted for an example.
|
||||||
answer_loss_cutoff: (:obj:`float`, `optional`):
|
answer_loss_cutoff (:obj:`float`, `optional`):
|
||||||
Ignore examples with answer loss larger than cutoff.
|
Ignore examples with answer loss larger than cutoff.
|
||||||
max_num_rows: (:obj:`int`, `optional`, defaults to 64):
|
max_num_rows (:obj:`int`, `optional`, defaults to 64):
|
||||||
Maximum number of rows.
|
Maximum number of rows.
|
||||||
max_num_columns: (:obj:`int`, `optional`, defaults to 32):
|
max_num_columns (:obj:`int`, `optional`, defaults to 32):
|
||||||
Maximum number of columns.
|
Maximum number of columns.
|
||||||
average_logits_per_cell: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
average_logits_per_cell (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to average logits per cell.
|
Whether to average logits per cell.
|
||||||
select_one_column: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
select_one_column (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether to constrain the model to only select cells from a single column.
|
Whether to constrain the model to only select cells from a single column.
|
||||||
allow_empty_column_selection: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
allow_empty_column_selection (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to allow not to select any column.
|
Whether to allow not to select any column.
|
||||||
init_cell_selection_weights_to_zero: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
init_cell_selection_weights_to_zero (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to initialize cell selection weights to 0 so that the initial probabilities are 50%.
|
Whether to initialize cell selection weights to 0 so that the initial probabilities are 50%.
|
||||||
reset_position_index_per_cell: (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
reset_position_index_per_cell (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||||
Whether to restart position indexes at every cell (i.e. use relative position embeddings).
|
Whether to restart position indexes at every cell (i.e. use relative position embeddings).
|
||||||
disable_per_token_loss: (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
disable_per_token_loss (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to disable any (strong or weak) supervision on cells.
|
Whether to disable any (strong or weak) supervision on cells.
|
||||||
|
aggregation_labels (:obj:`Dict[int, label]`, `optional`):
|
||||||
|
The aggregation labels used to aggregate the results. For example, the WTQ models have the following
|
||||||
|
aggregation labels: :obj:`{0: "NONE", 1: "SUM", 2: "AVERAGE", 3: "COUNT"}`
|
||||||
|
no_aggregation_label_index (:obj:`int`, `optional`):
|
||||||
|
If the aggregation labels are defined and one of these labels represents "No aggregation", this should be
|
||||||
|
set to its index. For example, the WTQ models have the "NONE" aggregation label at index 0, so that value
|
||||||
|
should be set to 0 for these models.
|
||||||
|
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
|
@ -174,6 +182,8 @@ class TapasConfig(PretrainedConfig):
|
||||||
init_cell_selection_weights_to_zero=False,
|
init_cell_selection_weights_to_zero=False,
|
||||||
reset_position_index_per_cell=True,
|
reset_position_index_per_cell=True,
|
||||||
disable_per_token_loss=False,
|
disable_per_token_loss=False,
|
||||||
|
aggregation_labels=None,
|
||||||
|
no_aggregation_label_index=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
|
|
||||||
|
@ -217,3 +227,10 @@ class TapasConfig(PretrainedConfig):
|
||||||
self.init_cell_selection_weights_to_zero = init_cell_selection_weights_to_zero
|
self.init_cell_selection_weights_to_zero = init_cell_selection_weights_to_zero
|
||||||
self.reset_position_index_per_cell = reset_position_index_per_cell
|
self.reset_position_index_per_cell = reset_position_index_per_cell
|
||||||
self.disable_per_token_loss = disable_per_token_loss
|
self.disable_per_token_loss = disable_per_token_loss
|
||||||
|
|
||||||
|
# Aggregation hyperparameters
|
||||||
|
self.aggregation_labels = aggregation_labels
|
||||||
|
self.no_aggregation_label_index = no_aggregation_label_index
|
||||||
|
|
||||||
|
if isinstance(self.aggregation_labels, dict):
|
||||||
|
self.aggregation_labels = {int(k): v for k, v in aggregation_labels.items()}
|
||||||
|
|
|
@ -1905,12 +1905,14 @@ class TapasTokenizer(PreTrainedTokenizer):
|
||||||
this threshold will be selected.
|
this threshold will be selected.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:obj:`tuple` comprising various elements depending on the inputs: predicted_answer_coordinates
|
:obj:`tuple` comprising various elements depending on the inputs:
|
||||||
(``List[List[[tuple]]`` of length ``batch_size``): Predicted answer coordinates as a list of lists of
|
|
||||||
tuples. Each element in the list contains the predicted answer coordinates of a single example in the
|
- predicted_answer_coordinates (``List[List[tuple]]`` of length ``batch_size``): Predicted answer
|
||||||
batch, as a list of tuples. Each tuple is a cell, i.e. (row index, column index).
|
coordinates as a list of lists of tuples. Each element in the list contains the predicted answer
|
||||||
predicted_aggregation_indices (`optional`, returned when ``logits_aggregation`` is provided) ``List[int]``
|
coordinates of a single example in the batch, as a list of tuples. Each tuple is a cell, i.e. (row index,
|
||||||
of length ``batch_size``: Predicted aggregation operator indices of the aggregation head.
|
column index).
|
||||||
|
- predicted_aggregation_indices (``List[int]`` of length ``batch_size``, `optional`, returned when
|
||||||
|
``logits_aggregation`` is provided): Predicted aggregation operator indices of the aggregation head.
|
||||||
"""
|
"""
|
||||||
# input data is of type float32
|
# input data is of type float32
|
||||||
# np.log(np.finfo(np.float32).max) = 88.72284
|
# np.log(np.finfo(np.float32).max) = 88.72284
|
||||||
|
@ -1969,11 +1971,11 @@ class TapasTokenizer(PreTrainedTokenizer):
|
||||||
answer_coordinates = sorted(answer_coordinates)
|
answer_coordinates = sorted(answer_coordinates)
|
||||||
predicted_answer_coordinates.append(answer_coordinates)
|
predicted_answer_coordinates.append(answer_coordinates)
|
||||||
|
|
||||||
output = predicted_answer_coordinates
|
output = (predicted_answer_coordinates,)
|
||||||
|
|
||||||
if logits_agg is not None:
|
if logits_agg is not None:
|
||||||
predicted_aggregation_indices = logits_agg.argmax(dim=-1)
|
predicted_aggregation_indices = logits_agg.argmax(dim=-1)
|
||||||
output = (output, predicted_aggregation_indices.tolist())
|
output = (predicted_answer_coordinates, predicted_aggregation_indices.tolist())
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
|
@ -12,8 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import collections
|
||||||
|
|
||||||
import csv
|
import csv
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
@ -32,7 +31,7 @@ import numpy as np
|
||||||
|
|
||||||
from .configuration_utils import PretrainedConfig
|
from .configuration_utils import PretrainedConfig
|
||||||
from .data import SquadExample, SquadFeatures, squad_convert_examples_to_features
|
from .data import SquadExample, SquadFeatures, squad_convert_examples_to_features
|
||||||
from .file_utils import add_end_docstrings, is_tf_available, is_torch_available
|
from .file_utils import add_end_docstrings, is_tf_available, is_torch_available, requires_pandas
|
||||||
from .modelcard import ModelCard
|
from .modelcard import ModelCard
|
||||||
from .models.auto.configuration_auto import AutoConfig
|
from .models.auto.configuration_auto import AutoConfig
|
||||||
from .models.auto.tokenization_auto import AutoTokenizer
|
from .models.auto.tokenization_auto import AutoTokenizer
|
||||||
|
@ -68,6 +67,7 @@ if is_torch_available():
|
||||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
|
MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING,
|
||||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
AutoModel,
|
AutoModel,
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
|
@ -75,6 +75,7 @@ if is_torch_available():
|
||||||
AutoModelForQuestionAnswering,
|
AutoModelForQuestionAnswering,
|
||||||
AutoModelForSeq2SeqLM,
|
AutoModelForSeq2SeqLM,
|
||||||
AutoModelForSequenceClassification,
|
AutoModelForSequenceClassification,
|
||||||
|
AutoModelForTableQuestionAnswering,
|
||||||
AutoModelForTokenClassification,
|
AutoModelForTokenClassification,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -2058,6 +2059,274 @@ class QuestionAnsweringPipeline(Pipeline):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
    """
    Handles arguments for the TableQuestionAnsweringPipeline
    """

    def __call__(self, table=None, query=None, sequential=False, padding=True, truncation=True):
        """
        Normalize the supported call signatures into a list of table/query dictionaries.

        Args:
            table: Either the table itself (a :obj:`pd.DataFrame` or anything :obj:`pd.DataFrame` can be built
                from) when ``query`` is given, or, when ``query`` is :obj:`None`, a dictionary with ``"table"`` and
                ``"query"`` keys or a non-empty list of such dictionaries.
            query: The query (or queries) to ask about ``table``. Leave to :obj:`None` when ``table`` already
                carries the queries.
            sequential: Returned unchanged.
            padding: Returned unchanged.
            truncation: Returned unchanged.

        Returns:
            A tuple ``(tqa_pipeline_inputs, sequential, padding, truncation)`` where ``tqa_pipeline_inputs`` has
            the shape::

                [
                    {"table": pd.DataFrame, "query": ...},
                    ...,
                    {"table": pd.DataFrame, "query": ...}
                ]

        Raises:
            ValueError: If ``table`` is :obj:`None` or the inputs do not match one of the supported layouts.
        """
        requires_pandas(self)
        import pandas as pd

        if table is None:
            raise ValueError("Keyword argument `table` cannot be None.")
        elif query is None:
            if isinstance(table, dict) and table.get("query") is not None and table.get("table") is not None:
                tqa_pipeline_inputs = [table]
            elif isinstance(table, list) and len(table) > 0:
                if not all(isinstance(d, dict) for d in table):
                    # Materialize the types as a list: interpolating a bare generator expression would only print
                    # "<generator object ...>" and hide the offending types from the user.
                    raise ValueError(
                        f"Keyword argument `table` should be a list of dict, but is {[type(d) for d in table]}"
                    )

                # Only the first dictionary is inspected here.
                if table[0].get("query") is not None and table[0].get("table") is not None:
                    tqa_pipeline_inputs = table
                else:
                    raise ValueError(
                        "If keyword argument `table` is a list of dictionaries, each dictionary should have a "
                        f"`table` and `query` key, but the first dictionary only has keys {list(table[0].keys())}."
                    )
            else:
                raise ValueError(
                    "Invalid input. Keyword argument `table` should be either of type `dict` or `list`, but "
                    f"is {type(table)}"
                )
        else:
            tqa_pipeline_inputs = [{"table": table, "query": query}]

        # Convert every table to a DataFrame. The dictionaries are updated in place.
        for tqa_pipeline_input in tqa_pipeline_inputs:
            if not isinstance(tqa_pipeline_input["table"], pd.DataFrame):
                if tqa_pipeline_input["table"] is None:
                    raise ValueError("Table cannot be None.")

                tqa_pipeline_input["table"] = pd.DataFrame(tqa_pipeline_input["table"])

        return tqa_pipeline_inputs, sequential, padding, truncation
|
||||||
|
|
||||||
|
|
||||||
|
@add_end_docstrings(PIPELINE_INIT_ARGS)
class TableQuestionAnsweringPipeline(Pipeline):
    """
    Table Question Answering pipeline using a :obj:`ModelForTableQuestionAnswering`. This pipeline is only available in
    PyTorch.

    This tabular question answering pipeline can currently be loaded from :func:`~transformers.pipeline` using the
    following task identifier: :obj:`"table-question-answering"`.

    The models that this pipeline can use are models that have been fine-tuned on a tabular question answering task.
    See the up-to-date list of available models on `huggingface.co/models
    <https://huggingface.co/models?filter=table-question-answering>`__.
    """

    default_input_names = "table,query"

    def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._args_parser = args_parser

        if self.framework == "tf":
            raise ValueError("The TableQuestionAnsweringPipeline is only available in PyTorch.")

        self.check_model_type(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING)

        # Aggregation is only used when the config defines both the labels and a non-zero number of labels. A
        # config that lacks either attribute is treated as "no aggregation" instead of raising AttributeError.
        self.aggregate = bool(getattr(self.model.config, "aggregation_labels", None)) and bool(
            getattr(self.model.config, "num_aggregation_labels", None)
        )

    def batch_inference(self, **inputs):
        """
        Run the model once on the whole batch, without tracking gradients, and return its output.
        """
        with torch.no_grad():
            return self.model(**inputs)

    def sequential_inference(self, **inputs):
        """
        Inference used for models that need to process sequences in a sequential fashion, like the SQA models which
        handle conversational query related to a table.

        Each example of the batch is fed to the model on its own. From the second example on, column 3 of its token
        type IDs is overwritten with the cells selected as answers for the previous example.

        Returns:
            ``(logits,)`` or, when the pipeline aggregates, ``(logits, logits_aggregation)``, with the per-example
            tensors concatenated along the first dimension.
        """
        with torch.no_grad():
            all_logits = []
            all_aggregations = []
            prev_answers = None
            batch_size = inputs["input_ids"].shape[0]

            input_ids = inputs["input_ids"].to(self.device)
            attention_mask = inputs["attention_mask"].to(self.device)
            token_type_ids = inputs["token_type_ids"].to(self.device)
            token_type_ids_example = None

            for index in range(batch_size):
                # If sequences have already been processed, the token type IDs will be created according to the previous
                # answer.
                if prev_answers is not None:
                    prev_labels_example = token_type_ids_example[:, 3]  # shape (seq_len,)
                    model_labels = np.zeros_like(prev_labels_example.cpu().numpy())  # shape (seq_len,)

                    token_type_ids_example = token_type_ids[index]  # shape (seq_len, 7)
                    # Convert each column to a Python list once instead of once per token.
                    segment_ids = token_type_ids_example[:, 0].tolist()
                    column_ids = token_type_ids_example[:, 1].tolist()
                    row_ids = token_type_ids_example[:, 2].tolist()
                    for i, (segment_id, column_id, row_id) in enumerate(zip(segment_ids, column_ids, row_ids)):
                        # Column and row IDs are shifted by one to obtain the cell coordinates.
                        col_id = column_id - 1
                        row_id = row_id - 1

                        if row_id >= 0 and col_id >= 0 and segment_id == 1:
                            model_labels[i] = int(prev_answers[(col_id, row_id)])

                    token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device)

                input_ids_example = input_ids[index]
                attention_mask_example = attention_mask[index]  # shape (seq_len,)
                token_type_ids_example = token_type_ids[index]  # shape (seq_len, 7)
                outputs = self.model(
                    input_ids=input_ids_example.unsqueeze(0),
                    attention_mask=attention_mask_example.unsqueeze(0),
                    token_type_ids=token_type_ids_example.unsqueeze(0),
                )
                logits = outputs.logits

                if self.aggregate:
                    all_aggregations.append(outputs.logits_aggregation)

                all_logits.append(logits)

                # Per-token selection probabilities, zeroed out wherever the attention mask is 0.
                dist_per_token = torch.distributions.Bernoulli(logits=logits)
                probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to(
                    dist_per_token.probs.device
                )

                # Group the token probabilities by the cell they belong to.
                coords_to_probs = collections.defaultdict(list)
                segment_ids = token_type_ids_example[:, 0].tolist()
                column_ids = token_type_ids_example[:, 1].tolist()
                row_ids = token_type_ids_example[:, 2].tolist()
                token_probabilities = probabilities.squeeze().tolist()
                for p, segment_id, column_id, row_id in zip(token_probabilities, segment_ids, column_ids, row_ids):
                    col = column_id - 1
                    row = row_id - 1
                    if col >= 0 and row >= 0 and segment_id == 1:
                        coords_to_probs[(col, row)].append(p)

                # A cell counts as part of the answer when its mean token probability exceeds 0.5.
                prev_answers = {coords: np.array(probs).mean() > 0.5 for coords, probs in coords_to_probs.items()}

            logits_batch = torch.cat(tuple(all_logits), 0)

            return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0))

    def __call__(self, *args, **kwargs):
        r"""
        Answers queries according to a table. The pipeline accepts several types of inputs which are detailed below:

        - ``pipeline(table, query)``
        - ``pipeline(table, [query])``
        - ``pipeline(table=table, query=query)``
        - ``pipeline(table=table, query=[query])``
        - ``pipeline({"table": table, "query": query})``
        - ``pipeline({"table": table, "query": [query]})``
        - ``pipeline([{"table": table, "query": query}, {"table": table, "query": query}])``

        The :obj:`table` argument should be a dict or a DataFrame built from that dict, containing the whole table:

        Example::

            data = {
                "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
                "age": ["56", "45", "59"],
                "number of movies": ["87", "53", "69"],
                "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
            }

        This dictionary can be passed in as such, or can be converted to a pandas DataFrame:

        Example::

            import pandas as pd
            table = pd.DataFrame.from_dict(data)


        Args:
            table (:obj:`pd.DataFrame` or :obj:`Dict`):
                Pandas DataFrame or dictionary that will be converted to a DataFrame containing all the table values.
                See above for an example of dictionary.
            query (:obj:`str` or :obj:`List[str]`):
                Query or list of queries that will be sent to the model alongside the table.
            sequential (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether to do inference sequentially or as a batch. Batching is faster, but models like SQA require the
                inference to be done sequentially to extract relations within sequences, given their conversational
                nature.
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
                Activates and controls padding. Accepts the following values:

                * :obj:`True` or :obj:`'longest'` (default): Pad to the longest sequence in the batch (or no padding
                  if only a single sequence is provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
                  different lengths).

            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.TapasTruncationStrategy`, `optional`, defaults to :obj:`True`):
                Activates and controls truncation. Accepts the following values:

                * :obj:`True` or :obj:`'drop_rows_to_fit'` (default): Truncate to a maximum length specified with the
                  argument :obj:`max_length` or to the maximum acceptable input length for the model if that argument
                  is not provided. This will truncate row by row, removing rows from the table.
                * :obj:`False` or :obj:`'do_not_truncate'`: No truncation (i.e., can output batch with
                  sequence lengths greater than the model maximum admissible input size).


        Return:
            A dictionary or a list of dictionaries containing results: Each result is a dictionary with the following
            keys:

            - **answer** (:obj:`str`) -- The answer of the query given the table. If there is an aggregator, the answer
              will be preceded by :obj:`AGGREGATOR >`.
            - **coordinates** (:obj:`List[Tuple[int, int]]`) -- Coordinates of the cells of the answers.
            - **cells** (:obj:`List[str]`) -- List of strings made up of the answer cell values.
            - **aggregator** (:obj:`str`) -- If the model has an aggregator, this returns the aggregator.
        """
        pipeline_inputs, sequential, padding, truncation = self._args_parser(*args, **kwargs)

        # Honor the caller's ``truncation`` instead of always forcing row dropping. ``True`` (the argument handler
        # default) keeps the previous behaviour by being spelled out as the explicit strategy name.
        if truncation is True:
            truncation = "drop_rows_to_fit"

        batched_answers = []
        for pipeline_input in pipeline_inputs:
            table, query = pipeline_input["table"], pipeline_input["query"]
            inputs = self.tokenizer(
                table, query, return_tensors=self.framework, truncation=truncation, padding=padding
            )

            outputs = self.sequential_inference(**inputs) if sequential else self.batch_inference(**inputs)

            if self.aggregate:
                logits, logits_agg = outputs[:2]
                predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits.detach(), logits_agg)
                answer_coordinates_batch, agg_predictions = predictions
                aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}

                # Only queries whose predicted operator is not the "no aggregation" label get an answer prefix.
                no_agg_label_index = self.model.config.no_aggregation_label_index
                aggregators_prefix = {
                    i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
                }
            else:
                logits = outputs[0]
                predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits.detach())
                answer_coordinates_batch = predictions[0]
                aggregators = {}
                aggregators_prefix = {}

            answers = []
            for index, coordinates in enumerate(answer_coordinates_batch):
                # Look the cell values up once and reuse them for both the answer string and the ``cells`` entry.
                cells = [table.iat[coordinate] for coordinate in coordinates]
                aggregator = aggregators.get(index, "")
                aggregator_prefix = aggregators_prefix.get(index, "")
                answer = {
                    "answer": aggregator_prefix + ", ".join(cells),
                    "coordinates": coordinates,
                    "cells": cells,
                }
                if aggregator:
                    answer["aggregator"] = aggregator

                answers.append(answer)
            # A single query yields a bare dictionary rather than a one-element list.
            batched_answers.append(answers if len(answers) > 1 else answers[0])
        # Likewise, a single table/query input is not wrapped in an outer list.
        return batched_answers if len(batched_answers) > 1 else batched_answers[0]
|
||||||
|
|
||||||
|
|
||||||
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
@add_end_docstrings(PIPELINE_INIT_ARGS)
|
||||||
class SummarizationPipeline(Pipeline):
|
class SummarizationPipeline(Pipeline):
|
||||||
"""
|
"""
|
||||||
|
@ -2752,6 +3021,18 @@ SUPPORTED_TASKS = {
|
||||||
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
|
"model": {"pt": "distilbert-base-cased-distilled-squad", "tf": "distilbert-base-cased-distilled-squad"},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"table-question-answering": {
|
||||||
|
"impl": TableQuestionAnsweringPipeline,
|
||||||
|
"pt": AutoModelForTableQuestionAnswering if is_torch_available() else None,
|
||||||
|
"tf": None,
|
||||||
|
"default": {
|
||||||
|
"model": {
|
||||||
|
"pt": "nielsr/tapas-base-finetuned-wtq",
|
||||||
|
"tokenizer": "nielsr/tapas-base-finetuned-wtq",
|
||||||
|
"tf": "nielsr/tapas-base-finetuned-wtq",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
"fill-mask": {
|
"fill-mask": {
|
||||||
"impl": FillMaskPipeline,
|
"impl": FillMaskPipeline,
|
||||||
"tf": TFAutoModelForMaskedLM if is_tf_available() else None,
|
"tf": TFAutoModelForMaskedLM if is_tf_available() else None,
|
||||||
|
@ -3006,6 +3287,12 @@ def pipeline(
|
||||||
"Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
|
"Model might be a PyTorch model (ending with `.bin`) but PyTorch is not available. "
|
||||||
"Trying to load the model with Tensorflow."
|
"Trying to load the model with Tensorflow."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if model_class is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Pipeline using {framework} framework, but this framework is not supported by this pipeline."
|
||||||
|
)
|
||||||
|
|
||||||
model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs)
|
model = model_class.from_pretrained(model, config=config, revision=revision, **model_kwargs)
|
||||||
if task == "translation" and model.config.task_specific_params:
|
if task == "translation" and model.config.task_specific_params:
|
||||||
for key in model.config.task_specific_params:
|
for key in model.config.task_specific_params:
|
||||||
|
|
|
@ -172,6 +172,19 @@ def require_torch(test_case):
|
||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_scatter(test_case):
    """
    Decorator for tests that depend on PyTorch scatter.

    The test is handed back untouched when the module-level ``_scatter_available`` flag is truthy; otherwise it is
    wrapped with :func:`unittest.skip` so that it is reported as skipped instead of being executed.
    """
    if _scatter_available:
        return test_case
    return unittest.skip("test requires PyTorch scatter")(test_case)
|
||||||
|
|
||||||
|
|
||||||
def require_tf(test_case):
|
def require_tf(test_case):
|
||||||
"""
|
"""
|
||||||
Decorator marking a test that requires TensorFlow.
|
Decorator marking a test that requires TensorFlow.
|
||||||
|
|
|
@ -0,0 +1,234 @@
|
||||||
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers.pipelines import Pipeline, pipeline
|
||||||
|
from transformers.testing_utils import require_pandas, require_torch, require_torch_scatter, slow
|
||||||
|
|
||||||
|
from .test_pipelines_common import CustomInputPipelineCommonMixin
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch_scatter
@require_torch
@require_pandas
class TQAPipelineTests(CustomInputPipelineCommonMixin, unittest.TestCase):
    """
    Tests for the ``table-question-answering`` pipeline.

    The class attributes below (task name, running kwargs, model lists and valid inputs) are the configuration
    consumed by ``CustomInputPipelineCommonMixin``; ``_test_pipeline`` is the per-pipeline check it is given.
    """

    pipeline_task = "table-question-answering"
    pipeline_running_kwargs = {"padding": "max_length"}
    small_models = ["lysandre/tiny-tapas-random-wtq", "lysandre/tiny-tapas-random-sqa"]
    large_models = ["nielsr/tapas-base-finetuned-wtq"]  # Models tested with the @slow decorator
    valid_inputs = [
        # One table, a single query given as a string.
        {
            "table": {
                "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
                "age": ["56", "45", "59"],
                "number of movies": ["87", "53", "69"],
                "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
            },
            "query": "how many movies has george clooney played in?",
        },
        # Same table, several queries where the later ones refer back to the first.
        {
            "table": {
                "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
                "age": ["56", "45", "59"],
                "number of movies": ["87", "53", "69"],
                "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
            },
            "query": ["how many movies has george clooney played in?", "how old is he?", "what's his date of birth?"],
        },
        # A different table with a list of independent queries.
        {
            "table": {
                "Repository": ["Transformers", "Datasets", "Tokenizers"],
                "Stars": ["36542", "4512", "3934"],
                "Contributors": ["651", "77", "34"],
                "Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
            },
            "query": [
                "What repository has the largest number of stars?",
                "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
                "What is the number of repositories?",
                "What is the average number of stars?",
                "What is the total amount of stars?",
            ],
        },
    ]

    @staticmethod
    def _as_answer_list(result):
        # A pipeline result is either one answer dict or a list of them; always hand back a list.
        return result if isinstance(result, list) else [result]

    def _check_aggregator_outputs(self, **call_kwargs):
        # Shared body of the aggregation tests: builds the tiny WTQ pipeline, checks its aggregation config and
        # verifies that every answer carries an "aggregator" entry. `call_kwargs` is forwarded to each pipeline call.
        table_querier = pipeline(
            "table-question-answering",
            model="lysandre/tiny-tapas-random-wtq",
            tokenizer="lysandre/tiny-tapas-random-wtq",
        )
        model_config = table_querier.model.config
        self.assertIsInstance(model_config.aggregation_labels, dict)
        self.assertIsInstance(model_config.no_aggregation_label_index, int)

        single_output = table_querier(self.valid_inputs[0], **call_kwargs)
        batch_output = table_querier(self.valid_inputs, **call_kwargs)

        self.assertIn("aggregator", single_output)

        for output in batch_output:
            for answer in self._as_answer_list(output):
                self.assertIn("aggregator", answer)

    def _test_pipeline(self, table_querier: Pipeline):
        # Checks the output structure for a single input and for a batch, then checks that an empty or missing
        # table is rejected with a ValueError.
        expected_keys = {"answer", "coordinates", "cells"}
        samples = self.valid_inputs
        rejected_samples = [
            {"query": "What does it do with empty context ?", "table": ""},
            {"query": "What does it do with empty context ?", "table": None},
        ]
        self.assertIsNotNone(table_querier)

        # Single input -> single answer dict.
        single_output = table_querier(samples[0])
        self.assertIsInstance(single_output, dict)
        for key in expected_keys:
            self.assertIn(key, single_output)

        # Batch of inputs -> list whose items are an answer dict or a list of answer dicts.
        batch_output = table_querier(samples)
        self.assertIsInstance(batch_output, list)
        for output in batch_output:
            self.assertIsInstance(output, (list, dict))

        for output in batch_output:
            for answer in self._as_answer_list(output):
                for key in expected_keys:
                    self.assertIn(key, answer)

        # Invalid tables must raise, both one at a time and as a batch.
        for rejected in rejected_samples:
            self.assertRaises(ValueError, table_querier, rejected)
        self.assertRaises(ValueError, table_querier, rejected_samples)

    def test_aggregation(self):
        self._check_aggregator_outputs()

    def test_aggregation_with_sequential(self):
        self._check_aggregator_outputs(sequential=True)

    def test_sequential(self):
        table_querier = pipeline(
            "table-question-answering",
            model="lysandre/tiny-tapas-random-sqa",
            tokenizer="lysandre/tiny-tapas-random-sqa",
        )
        single_query_input = self.valid_inputs[0]
        multi_query_input = self.valid_inputs[1]

        sequential_single = table_querier(single_query_input, sequential=True)
        sequential_multi_query = table_querier(multi_query_input, sequential=True)
        sequential_batch = table_querier(self.valid_inputs, sequential=True)
        plain_single = table_querier(single_query_input)
        plain_multi_query = table_querier(multi_query_input)
        plain_batch = table_querier(self.valid_inputs)

        # With only one question, sequential and non-sequential runs must produce the same dict.
        self.assertDictEqual(sequential_single, plain_single)

        # With several questions, the answers after the first one are expected to differ between the two modes.
        self.assertNotEqual(sequential_multi_query, plain_multi_query)

        # The same relationship must hold item by item when a whole batch is passed in.
        for position, (sequential_item, plain_item) in enumerate(zip(sequential_batch, plain_batch)):
            if position > 0:
                self.assertNotEqual(sequential_item, plain_item)
            else:
                self.assertDictEqual(sequential_item, plain_item)

    @slow
    def test_integration_wtq(self):
        # Runs the default checkpoint of the task on a small table and compares against fixed answers.
        tqa_pipeline = pipeline("table-question-answering")

        table = {
            "Repository": ["Transformers", "Datasets", "Tokenizers"],
            "Stars": ["36542", "4512", "3934"],
            "Contributors": ["651", "77", "34"],
            "Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
        }
        questions = [
            "What repository has the largest number of stars?",
            "Given that the numbers of stars defines if a repository is active, what repository is the most active?",
            "What is the number of repositories?",
            "What is the average number of stars?",
            "What is the total amount of stars?",
        ]

        actual = tqa_pipeline(table, questions)

        top_repository = {"answer": "Transformers", "coordinates": [(0, 0)], "cells": ["Transformers"]}
        all_repositories = {
            "answer": "Transformers, Datasets, Tokenizers",
            "coordinates": [(0, 0), (1, 0), (2, 0)],
            "cells": ["Transformers", "Datasets", "Tokenizers"],
        }
        all_stars = {
            "answer": "36542, 4512, 3934",
            "coordinates": [(0, 1), (1, 1), (2, 1)],
            "cells": ["36542", "4512", "3934"],
        }
        # Copies keep each expected entry an independent object, as in a literal list.
        expected = [top_repository, dict(top_repository), all_repositories, all_stars, dict(all_stars)]
        self.assertListEqual(actual, expected)

    @slow
    def test_integration_sqa(self):
        # Runs the SQA checkpoint in sequential mode; the follow-up questions depend on the first one.
        tqa_pipeline = pipeline(
            "table-question-answering",
            model="nielsr/tapas-base-finetuned-sqa",
            tokenizer="nielsr/tapas-base-finetuned-sqa",
        )
        table = {
            "Actors": ["Brad Pitt", "Leonardo Di Caprio", "George Clooney"],
            "Age": ["56", "45", "59"],
            "Number of movies": ["87", "53", "69"],
            "Date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
        }
        questions = ["How many movies has George Clooney played in?", "How old is he?", "What's his date of birth?"]
        actual = tqa_pipeline(table, questions, sequential=True)

        expected = [
            {"answer": "69", "coordinates": [(2, 2)], "cells": ["69"]},
            {"answer": "59", "coordinates": [(2, 1)], "cells": ["59"]},
            {"answer": "28 november 1967", "coordinates": [(2, 3)], "cells": ["28 november 1967"]},
        ]
        self.assertListEqual(actual, expected)
|
Loading…
Reference in New Issue