import jax
import jax.numpy as jnp
from bigbird_flax import FlaxBigBirdForNaturalQuestions
from datasets import load_from_disk
from transformers import BigBirdTokenizerFast
# Maps the class index predicted by the model's pooled head to a category name.
CATEGORY_MAPPING = dict(enumerate(("null", "short", "long", "yes", "no")))
# Characters that are replaced by a space when answers are normalised to aliases.
PUNCTUATION_SET_TO_EXCLUDE = {"‘", "’", "´", "`", ".", ",", "-", '"'}
def get_sub_answers(answers, begin=0, end=None):
    """Return the ``[begin:end]`` word slice of every multi-word answer.

    Answers are split on single spaces. Answers consisting of a single word
    are skipped entirely, so the result may be shorter than ``answers``.

    Args:
        answers: Iterable of answer strings.
        begin: Index of the first word to keep.
        end: Index one past the last word to keep (``None`` keeps the tail).

    Returns:
        List of space-joined word slices, in input order.
    """
    sub_answers = []
    for answer in answers:
        # Split once and reuse the result for both the length check and the slice.
        words = answer.split(" ")
        if len(words) > 1:
            sub_answers.append(" ".join(words[begin:end]))
    return sub_answers
def expand_to_aliases(given_answers, make_sub_answers=False):
    """Normalise answers into a set of comparable aliases.

    Each answer has underscores turned into spaces, is lower-cased, has every
    character in ``PUNCTUATION_SET_TO_EXCLUDE`` replaced by a space, and has
    runs of whitespace collapsed to a single space.

    Args:
        given_answers: List of answer strings.
        make_sub_answers: When true, also include every multi-word answer with
            its first word removed and with its last word removed.

    Returns:
        Set of normalised alias strings.
    """
    if make_sub_answers:
        # If answers are longer than one word, make sure a prediction is counted as
        # correct if it corresponds to the complete 1: or :-1 sub-sequence of words,
        # *e.g.* if the correct answer contains a prefix such as "the" or "a".
        given_answers = (
            given_answers + get_sub_answers(given_answers, begin=1) + get_sub_answers(given_answers, end=-1)
        )

    aliases = set()
    for answer in given_answers:
        alias = answer.replace("_", " ").lower()
        alias = "".join(" " if c in PUNCTUATION_SET_TO_EXCLUDE else c for c in alias)
        # split() + join collapses whitespace and already trims both ends.
        aliases.add(" ".join(alias.split()))
    return aliases
def get_best_valid_start_end_idx(start_scores, end_scores, top_k=1, max_size=100):
    """Pick the best-scoring (start, end) pair among the top-k candidates.

    The ``top_k`` highest start scores and ``top_k`` highest end scores are
    combined pairwise. Pairs where the end comes before the start, or where
    ``end - start`` exceeds ``max_size``, are penalised by ``1e8`` so that they
    only win when no valid pair exists.

    Args:
        start_scores: 1-D array of per-position start scores.
        end_scores: 1-D array of per-position end scores.
        top_k: Number of start and end candidates to consider.
        max_size: Largest allowed value of ``end - start``.

    Returns:
        Tuple ``(start_idx, end_idx)`` of positions into the score arrays.
    """
    best_start_scores, best_start_idx = jax.lax.top_k(start_scores, top_k)
    best_end_scores, best_end_idx = jax.lax.top_k(end_scores, top_k)

    # Rows index end candidates, columns index start candidates.
    widths = best_end_idx[:, None] - best_start_idx[None, :]
    invalid = jnp.logical_or(widths < 0, widths > max_size)
    scores = (best_end_scores[:, None] + best_start_scores[None, :]) - (1e8 * invalid)

    # argmax works on the flattened (top_k, top_k) matrix; recover row and column.
    end_pos, start_pos = divmod(jnp.argmax(scores).item(), top_k)
    return best_start_idx[start_pos], best_end_idx[end_pos]
def format_dataset(sample):
    """Flatten one Natural Questions record into question, context and targets.

    Tokens flagged in ``is_html`` are dropped from the context and from long
    answers.

    Args:
        sample: Record providing ``question.text``, ``document.tokens.token``,
            ``document.tokens.is_html`` and ``annotations`` with the keys
            ``yes_no_answer``, ``long_answer`` and ``short_answers``.

    Returns:
        Dict with the keys ``question``, ``context``, ``short``, ``long`` and
        ``category``. ``category`` is ``"yes"`` / ``"no"`` when a yes-no
        annotation is present, ``"long_short"`` when any short or long target
        was found, and ``"null"`` otherwise.
    """
    question = sample["question"]["text"]
    context = sample["document"]["tokens"]["token"]
    is_html = sample["document"]["tokens"]["is_html"]
    long_answers = sample["annotations"]["long_answer"]
    short_answers = sample["annotations"]["short_answers"]

    context_string = " ".join(token for token, html in zip(context, is_html) if not html)

    # 0 - No ; 1 - Yes
    for answer in sample["annotations"]["yes_no_answer"]:
        if answer in (0, 1):
            return {
                "question": question,
                "context": context_string,
                "short": [],
                "long": [],
                "category": "no" if answer == 0 else "yes",
            }

    # NOTE(review): assumes each short-answer annotation carries its answer strings
    # as a list under "text" - confirm against the dataset schema.
    short_targets = []
    for s in short_answers:
        short_targets.extend(s["text"])
    # De-duplicate while keeping first-seen order so the output is deterministic.
    short_targets = list(dict.fromkeys(short_targets))

    long_targets = []
    for s in long_answers:
        start, end = s["start_token"], s["end_token"]
        if start == -1:
            # -1 marks an annotation without a long answer; nothing to extract.
            continue
        new_answer = " ".join(
            token for token, html in zip(context[start:end], is_html[start:end]) if not html
        )
        if new_answer not in long_targets:
            long_targets.append(new_answer)

    category = "long_short" if short_targets or long_targets else "null"

    return {
        "question": question,
        "context": context_string,
        "short": short_targets,
        "long": long_targets,
        "category": category,
    }
def main():
    """Evaluate the BigBird Natural Questions checkpoint and print the EM score.

    Loads the dataset stored at ``natural-questions-validation``, reformats it
    with ``format_dataset``, drops examples that are too long or have no
    target, runs the model on every remaining example and prints the
    percentage of examples whose prediction matches a target.
    """
    dataset = load_from_disk("natural-questions-validation")
    dataset = dataset.map(format_dataset).remove_columns(["annotations", "document", "id"])

    # Keep only examples whose question + context stay under 4 * 4096 characters,
    # then drop the ones without any target.
    short_validation_dataset = dataset.filter(lambda x: (len(x["question"]) + len(x["context"])) < 4 * 4096)
    short_validation_dataset = short_validation_dataset.filter(lambda x: x["category"] != "null")

    model_id = "vasudevgupta/flax-bigbird-natural-questions"
    model = FlaxBigBirdForNaturalQuestions.from_pretrained(model_id)
    tokenizer = BigBirdTokenizerFast.from_pretrained(model_id)

    def forward(*args, **kwargs):
        """Run the model and reduce the pooled logits to a class index."""
        start_logits, end_logits, pooled_logits = model(*args, **kwargs)
        return start_logits, end_logits, jnp.argmax(pooled_logits, axis=-1)

    def evaluate(example):
        """Add ``targets``, ``output``, ``match``, ``has_tgt`` and ``has_pred`` to one example."""
        # encode question and context so that they are separated by a tokenizer.sep_token and cut at max_length
        # NOTE(review): the arguments of this call were missing and have been
        # reconstructed from how `inputs` is used below - confirm max_length,
        # padding and tensor type against the training configuration.
        inputs = tokenizer(
            example["question"],
            example["context"],
            return_tensors="np",
            max_length=4096,
            padding="max_length",
            truncation=True,
        )

        start_scores, end_scores, category = forward(**inputs)
        predicted_category = CATEGORY_MAPPING[category.item()]

        example["targets"] = example["long"] + example["short"]
        if example["category"] in ["yes", "no", "null"]:
            example["targets"] = [example["category"]]
        example["has_tgt"] = example["category"] != "null"
        # Now target can be: "yes", "no", "null", "list of long & short answers"

        if predicted_category in ["yes", "no", "null"]:
            # Categorical prediction: compare the label itself, no span decoding needed.
            example["output"] = [predicted_category]
            example["match"] = example["output"] == example["targets"]
            example["has_pred"] = predicted_category != "null"
            return example

        # Short answers are limited to a much narrower span than long answers.
        max_size = 38 if predicted_category == "short" else 1024
        start_idx, end_idx = get_best_valid_start_end_idx(
            start_scores[0], end_scores[0], top_k=8, max_size=max_size
        )

        input_ids = inputs["input_ids"][0].tolist()
        example["output"] = [tokenizer.decode(input_ids[start_idx : end_idx + 1])]

        answers = expand_to_aliases(example["targets"], make_sub_answers=True)
        predictions = expand_to_aliases(example["output"])

        # some preprocessing to both prediction and answer
        answers = {"".join(a.split()) for a in answers}
        predictions = {"".join(p.split()) for p in predictions}
        predictions = {s for s in predictions if s not in ["``", "''", "`", "'"]}

        # if there is a common element, it's an exact match
        example["match"] = len(answers & predictions) > 0
        example["has_pred"] = predicted_category != "null" and len(predictions) > 0

        return example

    short_validation_dataset = short_validation_dataset.map(evaluate)

    total = len(short_validation_dataset)
    if total == 0:
        # Nothing survived the filters; avoid dividing by zero below.
        print("EM score: n/a (no examples to evaluate)")
        return
    matched = len(short_validation_dataset.filter(lambda x: x["match"] == 1))
    print("EM score:", (matched / total) * 100, "%")
if __name__ == "__main__":