transformers/examples/research_projects/tapex/wikisql_utils.py

258 lines
7.8 KiB
Python

# coding=utf-8
# Copyright 2022 The Microsoft, The Google and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import enum
import functools
import math
import re
# The following script is adapted from the script of TaPas.
# Original: https://github.com/google-research/tapas/master/wikisql_utils.py
from typing import Any, List
EMPTY_ANSWER = "none"
EMPTY_ANSWER_AGG = "none"
def _split_thousands(delimiter, value):
split = value.split(delimiter)
return len(split) > 1 and any((len(x) == 3 for x in split))
def convert_to_float(value):
"""Converts value to a float using a series of increasingly complex heuristics.
Args:
value: object that needs to be converted. Allowed types include
float/int/strings.
Returns:
A float interpretation of value.
Raises:
ValueError if the float conversion of value fails.
"""
if isinstance(value, float):
return value
if isinstance(value, int):
return float(value)
if not isinstance(value, str):
raise ValueError("Argument value is not a string. Can't parse it as float")
sanitized = value
try:
# Example: 1,000.7
if "." in sanitized and "," in sanitized:
return float(sanitized.replace(",", ""))
# 1,000
if "," in sanitized and _split_thousands(",", sanitized):
return float(sanitized.replace(",", ""))
# 5,5556
if "," in sanitized and sanitized.count(",") == 1 and not _split_thousands(",", sanitized):
return float(sanitized.replace(",", "."))
# 0.0.0.1
if sanitized.count(".") > 1:
return float(sanitized.replace(".", ""))
# 0,0,0,1
if sanitized.count(",") > 1:
return float(sanitized.replace(",", ""))
return float(sanitized)
except ValueError:
# Avoid adding the sanitized value in the error message.
raise ValueError("Unable to convert value to float")
def _normalize_float(answer):
if answer is None:
return None
try:
value = convert_to_float(answer)
if isinstance(value, float) and math.isnan(value):
return None
return value
except ValueError:
return answer.lower()
_TYPE_CONVERTER = {
"text": lambda x: x,
"real": convert_to_float,
}
class _Aggregation(enum.Enum):
"""Aggregations as defined by WikiSQL. Indexes match the data."""
NONE = 0
MAX = 1
MIN = 2
COUNT = 3
SUM = 4
AVERAGE = 5
class _Operator(enum.Enum):
"""The boolean operators used by WikiSQL. Indexes match the data."""
EQUALS = 0
GREATER = 1
LESSER = 2
@dataclasses.dataclass
class _Condition:
"""Represents an SQL where clauses (e.g A = "a" or B > 5)."""
column: str
operator: _Operator
cmp_value: Any
_TOKENIZER = re.compile(r"\w+|[^\w\s]+", re.UNICODE | re.MULTILINE | re.DOTALL)
def _normalize_for_match(x):
return list(_TOKENIZER.findall(x.lower()))
def _compare(operator, src, tgt):
if operator == _Operator.EQUALS:
return src == tgt
elif operator == _Operator.GREATER:
return src > tgt
elif operator == _Operator.LESSER:
return src < tgt
raise ValueError(f"Unknown operator: {operator}")
def _parse_value(table, column, cell_value):
"""Convert numeric values to floats and keeps everything else as string."""
types = table["types"]
return _TYPE_CONVERTER[types[column]](cell_value)
def _is_string(x):
return isinstance(x, str)
def _respect_conditions(table, row, conditions):
"""True if 'row' satisfies all 'conditions'."""
for cond in conditions:
table_value = row[cond.column]
cmp_value = _parse_value(table, cond.column, cond.cmp_value)
if _is_string(table_value) and _is_string(cmp_value):
table_value = _normalize_for_match(table_value)
cmp_value = _normalize_for_match(cmp_value)
if not isinstance(table_value, type(cmp_value)):
raise ValueError("Type difference {} != {}".format(type(table_value), type(cmp_value)))
if not _compare(cond.operator, table_value, cmp_value):
return False
return True
def _get_float_answer(table, answer_coordinates, aggregation_op):
"""Applies operation to produce reference float answer."""
if not answer_coordinates:
if aggregation_op == _Aggregation.COUNT:
return 0.0
else:
return EMPTY_ANSWER_AGG
# Count can support non numeric answers.
if aggregation_op == _Aggregation.COUNT:
return float(len(answer_coordinates))
# If we have just one answer, if float returns it or try a conversion.
values = [table["rows"][i][j] for (i, j) in answer_coordinates]
if len(answer_coordinates) == 1:
try:
return convert_to_float(values[0])
except ValueError as e:
if aggregation_op != _Aggregation.NONE:
raise e
if aggregation_op == _Aggregation.NONE:
return None
# Other aggregation only support numeric values. Bail out if we have strings.
if not all((isinstance(v, (int, float)) for v in values)):
return None
if aggregation_op == _Aggregation.SUM:
return float(sum(values))
elif aggregation_op == _Aggregation.AVERAGE:
return sum(values) / len(answer_coordinates)
else:
raise ValueError(f"Unknown aggregation: {aggregation_op}")
def _get_answer_coordinates(table, sql_query):
"""Retrieves references coordinates by executing SQL."""
# MAX and MIN are automatically supported by the model.
aggregation_op_index = sql_query["agg"]
if aggregation_op_index >= 3:
aggregation_op = _Aggregation(aggregation_op_index)
else:
aggregation_op = _Aggregation.NONE
target_column = sql_query["sel"]
conditions = [
_Condition(column, _Operator(operator), cmp_value)
for column, operator, cmp_value in zip(
sql_query["conds"]["column_index"], sql_query["conds"]["operator_index"], sql_query["conds"]["condition"]
)
]
indices = []
for row in range(len(table["rows"])):
if _respect_conditions(table, table["rows"][row], conditions):
indices.append((row, target_column))
if not indices:
return [], aggregation_op
if len(indices) == 1:
return indices, aggregation_op
# Parsing of MIN/MAX.
if aggregation_op_index in (1, 2):
operators = {2: min, 1: max}
values = [(table["rows"][i][j], index) for index, (i, j) in enumerate(indices)]
reduced = functools.reduce(operators[sql_query["agg"]], values)
ret = [indices[reduced[1]]]
return ret, _Aggregation.NONE
return indices, aggregation_op
def _get_answer_text(table, answer_coordinates, float_answer):
if float_answer is not None:
return [str(float_answer)]
return [str(table["real_rows"][r][c]) for r, c in answer_coordinates]
def retrieve_wikisql_query_answer_tapas(table, example) -> List:
answer_coordinates, aggregation_op = _get_answer_coordinates(table, example)
float_answer = _get_float_answer(table, answer_coordinates, aggregation_op)
answer_text = _get_answer_text(table, answer_coordinates, float_answer)
# keep the original data the same with TaPas
if len(answer_text) == 0:
answer_text = [EMPTY_ANSWER]
return answer_text