Add FlaxBigBird QuestionAnswering script (#12233)

* port bigbird script

* adapt script a bit

* change location

* adapt more

* save progress

* init commit

* style

* dataset script tested

* readme add
Vasudev Gupta 2021-06-25 22:35:48 +05:30 committed by GitHub
parent 55bb4c06f7
commit 332a245861
7 changed files with 976 additions and 0 deletions

README.md
@@ -0,0 +1,60 @@
Author: [@vasudevgupta7](https://github.com/vasudevgupta7)
## Intro
In this project, we fine-tuned [**BigBird**](https://arxiv.org/abs/2007.14062) on the [**natural-questions**](https://huggingface.co/datasets/natural_questions) dataset for the **question-answering** task on long documents. **BigBird** is a **sparse-attention based transformer** that extends Transformer-based models, such as BERT, to much **longer sequences**.
Read more about BigBird at https://huggingface.co/blog/big-bird
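As a quick orientation, this is roughly how the sparse-attention model behind this project is instantiated (a minimal sketch; the `block_size` / `num_random_blocks` values below are only illustrative, the values actually used for training live in `bigbird_flax.py`):

```python
from transformers import BigBirdTokenizerFast, FlaxBigBirdForQuestionAnswering

model_id = "google/bigbird-roberta-base"
tokenizer = BigBirdTokenizerFast.from_pretrained(model_id)
# block-sparse attention: every query attends to sliding-window, random & global blocks
model = FlaxBigBirdForQuestionAnswering.from_pretrained(
    model_id, block_size=64, num_random_blocks=3
)
```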
## Fine-tuning
**Setup**
You need to install JAX yourself by following the official docs ([see here](https://github.com/google/jax#installation)). The other requirements for this project can be installed by running the following command:
```shell
pip3 install -qr requirements.txt
```
**Download & prepare dataset**
The Natural Questions corpus contains questions from real users, and it requires QA systems to read and comprehend an entire Wikipedia article that may or may not contain the answer to the question. The corpus takes ~100 GB on disk. We used the HuggingFace `datasets` library to download & process the dataset.
```shell
# just run the following command
python3 prepare_natural_questions.py
# this will download the whole dataset from the HuggingFace Hub & make it ready for training
# the script takes ~3 hours to process the dataset
```
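Each line of the resulting `.jsonl` files holds one training window with the tokenized inputs and the answer labels. A minimal sketch for inspecting a record (the file name assumes the defaults used by `prepare_natural_questions.py`; adjust the path if you moved the file):

```python
import jsonlines

# every line looks like:
# {"input_ids": [...], "start_token": int, "end_token": int, "category": 0..4}
with jsonlines.open("nq-validation.jsonl") as reader:
    for sample in reader:
        print(len(sample["input_ids"]), sample["start_token"], sample["end_token"], sample["category"])
        break
```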
**Launch Training**
We trained on Google Cloud's TPU v3-8. Each epoch took around 4.5 hours and the model converged in just 2 epochs. You can see the complete training arguments in [this script](bigbird_flax.py).
```shell
# just run the following command
python3 train.py
# in case you want to try hyperparameter tuning, you can run a wandb sweep
wandb sweep --project=bigbird sweep_flax.yaml
wandb agent <agent-id-obtained-by-above-CMD>
```
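All hyperparameters (learning rate, warmup steps, per-device batch size, data paths, …) are defined in the `Args` dataclass inside [`bigbird_flax.py`](bigbird_flax.py). A minimal sketch of building a modified config in your own launcher (the values below are purely illustrative):

```python
from bigbird_flax import Args

# override a few defaults; __post_init__ derives batch_size from the local device count
args = Args(lr=4e-5, warmup_steps=10000, batch_size_per_device=2)
print(args.batch_size)  # batch_size_per_device * jax.device_count()
```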
## Evaluation
Our evaluation script differs from the original one: for simplicity, we only evaluate sequences with length up to 4096. With this script we obtained an **EM score of ~55.2**.
```shell
# download validation-dataset first
mkdir natural-questions-validation
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/natural_questions-validation.arrow -P natural-questions-validation
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/dataset_info.json -P natural-questions-validation
wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/state.json -P natural-questions-validation
# then simply run the following command
python3 evaluate.py
```
You can find our checkpoint on the HuggingFace Hub ([see this](https://huggingface.co/vasudevgupta/flax-bigbird-natural-questions)). In case you are interested in PyTorch BigBird fine-tuning, you can refer to [this repository](https://github.com/vasudevgupta7/bigbird).
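If you just want to poke at the uploaded checkpoint without running `evaluate.py`, something along these lines should work (a rough sketch; the question/context strings are placeholders, and the extra category head stored in the checkpoint is simply dropped when loading into the stock QA class):

```python
import jax.numpy as jnp
from transformers import BigBirdTokenizerFast, FlaxBigBirdForQuestionAnswering

model_id = "vasudevgupta/flax-bigbird-natural-questions"
tokenizer = BigBirdTokenizerFast.from_pretrained(model_id)
model = FlaxBigBirdForQuestionAnswering.from_pretrained(model_id)  # warns about unused `cls` head weights

question = "when was the natural questions dataset released?"
context = "The Natural Questions corpus was released by Google AI in 2019 ..."
inputs = tokenizer(question, context, return_tensors="jax", max_length=4096, padding="max_length", truncation=True)
outputs = model(**inputs)

start = jnp.argmax(outputs.start_logits, axis=-1).item()
end = jnp.argmax(outputs.end_logits, axis=-1).item()
print(tokenizer.decode(inputs["input_ids"][0].tolist()[start : end + 1]))
```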

bigbird_flax.py
@@ -0,0 +1,321 @@
import json
import os
from dataclasses import dataclass
from functools import partial
from typing import Callable
from tqdm.auto import tqdm
import flax.linen as nn
import jax
import jax.numpy as jnp
import joblib
import optax
import wandb
from flax import jax_utils, struct, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule
class FlaxBigBirdForNaturalQuestionsModule(FlaxBigBirdForQuestionAnsweringModule):
"""
BigBirdForQuestionAnswering with a CLS head on top for predicting the answer category.
This way we can load its weights with FlaxBigBirdForQuestionAnswering.
"""
config: BigBirdConfig
dtype: jnp.dtype = jnp.float32
add_pooling_layer: bool = True
def setup(self):
super().setup()
self.cls = nn.Dense(5, dtype=self.dtype)
def __call__(self, *args, **kwargs):
outputs = super().__call__(*args, **kwargs)
cls_out = self.cls(outputs[2])
return outputs[:2] + (cls_out,)
class FlaxBigBirdForNaturalQuestions(FlaxBigBirdForQuestionAnswering):
module_class = FlaxBigBirdForNaturalQuestionsModule
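# Natural Questions loss: token-level cross-entropy over start/end positions plus cross-entropy over the 5 answer categories, averaged.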
def calculate_loss_for_nq(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels):
def cross_entropy(logits, labels, reduction=None):
"""
Args:
logits: bsz, seqlen, vocab_size
labels: bsz, seqlen
"""
vocab_size = logits.shape[-1]
labels = (labels[..., None] == jnp.arange(vocab_size)[None]).astype("f4")
logits = jax.nn.log_softmax(logits, axis=-1)
loss = -jnp.sum(labels * logits, axis=-1)
if reduction is not None:
loss = reduction(loss)
return loss
cross_entropy = partial(cross_entropy, reduction=jnp.mean)
start_loss = cross_entropy(start_logits, start_labels)
end_loss = cross_entropy(end_logits, end_labels)
pooled_loss = cross_entropy(pooled_logits, pooler_labels)
return (start_loss + end_loss + pooled_loss) / 3
@dataclass
class Args:
model_id: str = "google/bigbird-roberta-base"
logging_steps: int = 3000
save_steps: int = 10500
block_size: int = 128
num_random_blocks: int = 3
batch_size_per_device: int = 1
max_epochs: int = 5
# tx_args
lr: float = 3e-5
init_lr: float = 0.0
warmup_steps: int = 20000
weight_decay: float = 0.0095
save_dir: str = "bigbird-roberta-natural-questions"
base_dir: str = "training-expt"
tr_data_path: str = "data/nq-training.jsonl"
val_data_path: str = "data/nq-validation.jsonl"
def __post_init__(self):
os.makedirs(self.base_dir, exist_ok=True)
self.save_dir = os.path.join(self.base_dir, self.save_dir)
self.batch_size = self.batch_size_per_device * jax.device_count()
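# Pads every example to a fixed max_length (static shapes for TPU) and shards each batch across the local devices.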
@dataclass
class DataCollator:
pad_id: int
max_length: int = 4096 # no dynamic padding on TPUs
def __call__(self, batch):
batch = self.collate_fn(batch)
batch = jax.tree_map(shard, batch)
return batch
def collate_fn(self, features):
input_ids, attention_mask = self.fetch_inputs(features["input_ids"])
batch = {
"input_ids": jnp.array(input_ids, dtype=jnp.int32),
"attention_mask": jnp.array(attention_mask, dtype=jnp.int32),
"start_labels": jnp.array(features["start_token"], dtype=jnp.int32),
"end_labels": jnp.array(features["end_token"], dtype=jnp.int32),
"pooled_labels": jnp.array(features["category"], dtype=jnp.int32),
}
return batch
def fetch_inputs(self, input_ids: list):
inputs = [self._fetch_inputs(ids) for ids in input_ids]
return zip(*inputs)
def _fetch_inputs(self, input_ids: list):
attention_mask = [1 for _ in range(len(input_ids))]
while len(input_ids) < self.max_length:
input_ids.append(self.pad_id)
attention_mask.append(0)
return input_ids, attention_mask
def get_batched_dataset(dataset, batch_size, seed=None):
if seed is not None:
dataset = dataset.shuffle(seed=seed)
for i in range(len(dataset) // batch_size):
batch = dataset[i * batch_size : (i + 1) * batch_size]
yield dict(batch)
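# train_step / val_step are pmapped over the batch axis: the loss (and gradients, during training) is averaged across devices with lax.pmean.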
@partial(jax.pmap, axis_name="batch")
def train_step(state, drp_rng, **model_inputs):
def loss_fn(params):
start_labels = model_inputs.pop("start_labels")
end_labels = model_inputs.pop("end_labels")
pooled_labels = model_inputs.pop("pooled_labels")
outputs = state.apply_fn(**model_inputs, params=params, dropout_rng=drp_rng, train=True)
start_logits, end_logits, pooled_logits = outputs
return state.loss_fn(
start_logits,
start_labels,
end_logits,
end_labels,
pooled_logits,
pooled_labels,
)
drp_rng, new_drp_rng = jax.random.split(drp_rng)
grad_fn = jax.value_and_grad(loss_fn)
loss, grads = grad_fn(state.params)
metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
grads = jax.lax.pmean(grads, "batch")
state = state.apply_gradients(grads=grads)
return state, metrics, new_drp_rng
@partial(jax.pmap, axis_name="batch")
def val_step(state, **model_inputs):
start_labels = model_inputs.pop("start_labels")
end_labels = model_inputs.pop("end_labels")
pooled_labels = model_inputs.pop("pooled_labels")
outputs = state.apply_fn(**model_inputs, params=state.params, train=False)
start_logits, end_logits, pooled_logits = outputs
loss = state.loss_fn(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels)
metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
return metrics
class TrainState(train_state.TrainState):
loss_fn: Callable = struct.field(pytree_node=False)
@dataclass
class Trainer:
args: Args
data_collator: Callable
train_step_fn: Callable
val_step_fn: Callable
model_save_fn: Callable
logger: wandb
scheduler_fn: Callable = None
def create_state(self, model, tx, num_train_steps, ckpt_dir=None):
params = model.params
state = TrainState.create(
apply_fn=model.__call__,
params=params,
tx=tx,
loss_fn=calculate_loss_for_nq,
)
if ckpt_dir is not None:
params, opt_state, step, args, data_collator = restore_checkpoint(ckpt_dir, state)
tx_args = {
"lr": args.lr,
"init_lr": args.init_lr,
"warmup_steps": args.warmup_steps,
"num_train_steps": num_train_steps,
"weight_decay": args.weight_decay,
}
tx, lr = build_tx(**tx_args)
state = train_state.TrainState(
step=step,
apply_fn=model.__call__,
params=params,
tx=tx,
opt_state=opt_state,
)
self.args = args
self.data_collator = data_collator
self.scheduler_fn = lr
model.params = params
state = jax_utils.replicate(state)
return state
def train(self, state, tr_dataset, val_dataset):
args = self.args
total = len(tr_dataset) // args.batch_size
rng = jax.random.PRNGKey(0)
drp_rng = jax.random.split(rng, jax.device_count())
for epoch in range(args.max_epochs):
running_loss = jnp.array(0, dtype=jnp.float32)
tr_dataloader = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
i = 0
for batch in tqdm(tr_dataloader, total=total, desc=f"Running EPOCH-{epoch}"):
batch = self.data_collator(batch)
state, metrics, drp_rng = self.train_step_fn(state, drp_rng, **batch)
running_loss += jax_utils.unreplicate(metrics["loss"])
i += 1
if i % args.logging_steps == 0:
state_step = jax_utils.unreplicate(state.step)
tr_loss = running_loss.item() / i
lr = self.scheduler_fn(state_step - 1)
eval_loss = self.evaluate(state, val_dataset)
logging_dict = dict(
step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
)
tqdm.write(str(logging_dict))
self.logger.log(logging_dict, commit=True)
if i % args.save_steps == 0:
self.save_checkpoint(args.save_dir + f"-e{epoch}-s{i}", state=state)
def evaluate(self, state, dataset):
dataloader = get_batched_dataset(dataset, self.args.batch_size)
total = len(dataset) // self.args.batch_size
running_loss = jnp.array(0, dtype=jnp.float32)
i = 0
for batch in tqdm(dataloader, total=total, desc="Evaluating ... "):
batch = self.data_collator(batch)
metrics = self.val_step_fn(state, **batch)
running_loss += jax_utils.unreplicate(metrics["loss"])
i += 1
return running_loss / i
def save_checkpoint(self, save_dir, state):
state = jax_utils.unreplicate(state)
print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
self.model_save_fn(save_dir, params=state.params)
with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
f.write(to_bytes(state.opt_state))
joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
joblib.dump(self.data_collator, os.path.join(save_dir, "data_collator.joblib"))
with open(os.path.join(save_dir, "training_state.json"), "w") as f:
json.dump({"step": state.step.item()}, f)
print("DONE")
def restore_checkpoint(save_dir, state):
print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
params = from_bytes(state.params, f.read())
with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
opt_state = from_bytes(state.opt_state, f.read())
args = joblib.load(os.path.join(save_dir, "args.joblib"))
data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))
with open(os.path.join(save_dir, "training_state.json"), "r") as f:
training_state = json.load(f)
step = training_state["step"]
print("DONE")
return params, opt_state, step, args, data_collator
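# Linear warmup from init_lr to lr followed by linear decay; AdamW skips weight decay for biases and LayerNorm scales.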
def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
decay_steps = num_train_steps - warmup_steps
warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps)
lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
return lr
def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
def weight_decay_mask(params):
params = traverse_util.flatten_dict(params)
mask = {k: (v[-1] != "bias" and v[-2:] != ("LayerNorm", "scale")) for k, v in params.items()}
return traverse_util.unflatten_dict(mask)
lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)
tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
return tx, lr

evaluate.py
@@ -0,0 +1,165 @@
from datasets import load_from_disk
import jax
import jax.numpy as jnp
from bigbird_flax import FlaxBigBirdForNaturalQuestions
from transformers import BigBirdTokenizerFast
CATEGORY_MAPPING = {0: "null", 1: "short", 2: "long", 3: "yes", 4: "no"}
PUNCTUATION_SET_TO_EXCLUDE = set("".join(["‘", "’", "´", "`", ".", ",", "-", '"']))
def get_sub_answers(answers, begin=0, end=None):
return [" ".join(x.split(" ")[begin:end]) for x in answers if len(x.split(" ")) > 1]
def expand_to_aliases(given_answers, make_sub_answers=False):
if make_sub_answers:
# if answers are longer than one word, also count a prediction as correct if it corresponds to the complete `1:` or `:-1` sub-answer,
# *e.g.* when the correct answer contains a prefix such as "the" or "a"
given_answers = (
given_answers + get_sub_answers(given_answers, begin=1) + get_sub_answers(given_answers, end=-1)
)
answers = []
for answer in given_answers:
alias = answer.replace("_", " ").lower()
alias = "".join(c if c not in PUNCTUATION_SET_TO_EXCLUDE else " " for c in alias)
answers.append(" ".join(alias.split()).strip())
return set(answers)
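# Among the top-k start/end candidates, pick the highest-scoring pair with start <= end and a span of at most max_size tokens.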
def get_best_valid_start_end_idx(start_scores, end_scores, top_k=1, max_size=100):
best_start_scores, best_start_idx = jax.lax.top_k(start_scores, top_k)
best_end_scores, best_end_idx = jax.lax.top_k(end_scores, top_k)
widths = best_end_idx[:, None] - best_start_idx[None, :]
mask = jnp.logical_or(widths < 0, widths > max_size)
scores = (best_end_scores[:, None] + best_start_scores[None, :]) - (1e8 * mask)
best_score = jnp.argmax(scores).item()
return best_start_idx[best_score % top_k], best_end_idx[best_score // top_k]
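# Flatten a natural_questions example: drop HTML tokens from the context and collect the short/long/yes/no targets.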
def format_dataset(sample):
question = sample["question"]["text"]
context = sample["document"]["tokens"]["token"]
is_html = sample["document"]["tokens"]["is_html"]
long_answers = sample["annotations"]["long_answer"]
short_answers = sample["annotations"]["short_answers"]
context_string = " ".join([context[i] for i in range(len(context)) if not is_html[i]])
# 0 - No ; 1 - Yes
for answer in sample["annotations"]["yes_no_answer"]:
if answer == 0 or answer == 1:
return {
"question": question,
"context": context_string,
"short": [],
"long": [],
"category": "no" if answer == 0 else "yes",
}
short_targets = []
for s in short_answers:
short_targets.extend(s["text"])
short_targets = list(set(short_targets))
long_targets = []
for s in long_answers:
if s["start_token"] == -1:
continue
answer = context[s["start_token"] : s["end_token"]]
html = is_html[s["start_token"] : s["end_token"]]
new_answer = " ".join([answer[i] for i in range(len(answer)) if not html[i]])
if new_answer not in long_targets:
long_targets.append(new_answer)
category = "long_short" if len(short_targets + long_targets) > 0 else "null"
return {
"question": question,
"context": context_string,
"short": short_targets,
"long": long_targets,
"category": category,
}
def main():
dataset = load_from_disk("natural-questions-validation")
dataset = dataset.map(format_dataset).remove_columns(["annotations", "document", "id"])
print(dataset)
short_validation_dataset = dataset.filter(lambda x: (len(x["question"]) + len(x["context"])) < 4 * 4096)
short_validation_dataset = short_validation_dataset.filter(lambda x: x["category"] != "null")
print(short_validation_dataset)
model_id = "vasudevgupta/flax-bigbird-natural-questions"
model = FlaxBigBirdForNaturalQuestions.from_pretrained(model_id)
tokenizer = BigBirdTokenizerFast.from_pretrained(model_id)
@jax.jit
def forward(*args, **kwargs):
start_logits, end_logits, pooled_logits = model(*args, **kwargs)
return start_logits, end_logits, jnp.argmax(pooled_logits, axis=-1)
def evaluate(example):
# encode question and context so that they are separated by a tokenizer.sep_token and cut at max_length
inputs = tokenizer(
example["question"],
example["context"],
return_tensors="jax",
max_length=4096,
padding="max_length",
truncation=True,
)
start_scores, end_scores, category = forward(**inputs)
predicted_category = CATEGORY_MAPPING[category.item()]
example["targets"] = example["long"] + example["short"]
if example["category"] in ["yes", "no", "null"]:
example["targets"] = [example["category"]]
example["has_tgt"] = example["category"] != "null"
# Now target can be: "yes", "no", "null", "list of long & short answers"
if predicted_category in ["yes", "no", "null"]:
example["output"] = [predicted_category]
example["match"] = example["output"] == example["targets"]
example["has_pred"] = predicted_category != "null"
return example
max_size = 38 if predicted_category == "short" else 1024
start_score, end_score = get_best_valid_start_end_idx(
start_scores[0], end_scores[0], top_k=8, max_size=max_size
)
input_ids = inputs["input_ids"][0].tolist()
example["output"] = [tokenizer.decode(input_ids[start_score : end_score + 1])]
answers = expand_to_aliases(example["targets"], make_sub_answers=True)
predictions = expand_to_aliases(example["output"])
# some preprocessing to both prediction and answer
answers = set(["".join(a.split()) for a in answers])
predictions = set(["".join(p.split()) for p in predictions])
predictions = set([s for s in predictions if s not in ["``", "''", "`", "'"]])
# if there is a common element, it's an exact match
example["match"] = len(list(answers & predictions)) > 0
example["has_pred"] = predicted_category != "null" and len(predictions) > 0
return example
short_validation_dataset = short_validation_dataset.map(evaluate)
total = len(short_validation_dataset)
matched = len(short_validation_dataset.filter(lambda x: x["match"] == 1))
print("EM score:", (matched / total) * 100, "%")
if __name__ == "__main__":
main()

prepare_natural_questions.py
@@ -0,0 +1,330 @@
import os
import numpy as np
from tqdm import tqdm
import jsonlines
DOC_STRIDE = 2048
MAX_LENGTH = 4096
SEED = 42
PROCESS_TRAIN = os.environ.pop("PROCESS_TRAIN", "false")
CATEGORY_MAPPING = {"null": 0, "short": 1, "long": 2, "yes": 3, "no": 4}
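# Pick a single annotation per example: yes/no if annotated, else the first non-empty short answer, else the long answer.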
def _get_single_answer(example):
def choose_first(answer, is_long_answer=False):
assert isinstance(answer, list)
if len(answer) == 1:
answer = answer[0]
return {k: [answer[k]] for k in answer} if is_long_answer else answer
for a in answer:
if is_long_answer:
a = {k: [a[k]] for k in a}
if len(a["start_token"]) > 0:
break
return a
answer = {"id": example["id"]}
annotation = example["annotations"]
yes_no_answer = annotation["yes_no_answer"]
if 0 in yes_no_answer or 1 in yes_no_answer:
answer["category"] = ["yes"] if 1 in yes_no_answer else ["no"]
answer["start_token"] = answer["end_token"] = []
answer["start_byte"] = answer["end_byte"] = []
answer["text"] = ["<cls>"]
else:
answer["category"] = ["short"]
out = choose_first(annotation["short_answers"])
if len(out["start_token"]) == 0:
# answer will be long if short is not available
answer["category"] = ["long"]
out = choose_first(annotation["long_answer"], is_long_answer=True)
out["text"] = []
answer.update(out)
# disregard some samples
if len(answer["start_token"]) > 1 or answer["start_token"] == answer["end_token"]:
answer["remove_it"] = True
else:
answer["remove_it"] = False
cols = ["start_token", "end_token", "start_byte", "end_byte", "text"]
if not all([isinstance(answer[k], list) for k in cols]):
raise ValueError("Issue in ID", example["id"])
return answer
def get_context_and_ans(example, assertion=False):
"""Gives new context after removing <html> & new answer tokens as per new context"""
answer = _get_single_answer(example)
# bytes are of no use
del answer["start_byte"]
del answer["end_byte"]
# handle yes_no answers explicitly
if answer["category"][0] in ["yes", "no"]: # category is list with one element
doc = example["document"]["tokens"]
context = []
for i in range(len(doc["token"])):
if not doc["is_html"][i]:
context.append(doc["token"][i])
return {
"context": " ".join(context),
"answer": {
"start_token": -100, # ignore index in cross-entropy
"end_token": -100, # ignore index in cross-entropy
"category": answer["category"],
"span": answer["category"], # extra
},
}
# helps later in removing all samples with no answer
if answer["start_token"] == [-1]:
return {
"context": "None",
"answer": {
"start_token": -1,
"end_token": -1,
"category": "null",
"span": "None", # extra
},
}
# handling normal samples
cols = ["start_token", "end_token"]
answer.update({k: answer[k][0] if len(answer[k]) > 0 else answer[k] for k in cols}) # e.g. [10] == 10
doc = example["document"]["tokens"]
start_token = answer["start_token"]
end_token = answer["end_token"]
context = []
for i in range(len(doc["token"])):
if not doc["is_html"][i]:
context.append(doc["token"][i])
else:
if answer["start_token"] > i:
start_token -= 1
if answer["end_token"] > i:
end_token -= 1
new = " ".join(context[start_token:end_token])
# checking above code
if assertion:
"""checking if above code is working as expected for all the samples"""
is_html = doc["is_html"][answer["start_token"] : answer["end_token"]]
old = doc["token"][answer["start_token"] : answer["end_token"]]
old = " ".join([old[i] for i in range(len(old)) if not is_html[i]])
if new != old:
print("ID:", example["id"])
print("New:", new, end="\n")
print("Old:", old, end="\n\n")
return {
"context": " ".join(context),
"answer": {
"start_token": start_token,
"end_token": end_token - 1, # this makes it inclusive
"category": answer["category"], # either long or short
"span": new, # extra
},
}
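# Tokenize question + cleaned context and split long documents into overlapping windows, each re-prefixed with the question tokens;
# answer token indices are remapped into every window, and windows that don't contain the answer get -100 labels / category "null".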
def get_strided_contexts_and_ans(example, tokenizer, doc_stride=2048, max_length=4096, assertion=True):
# overlap between consecutive windows will be doc_stride - q_len tokens
out = get_context_and_ans(example, assertion=assertion)
answer = out["answer"]
# later, removing these samples
if answer["start_token"] == -1:
return {
"example_id": example["id"],
"input_ids": [[-1]],
"labels": {
"start_token": [-1],
"end_token": [-1],
"category": ["null"],
},
}
input_ids = tokenizer(example["question"]["text"], out["context"]).input_ids
q_len = input_ids.index(tokenizer.sep_token_id) + 1
# return yes/no
if answer["category"][0] in ["yes", "no"]: # category is list with one element
inputs = []
category = []
q_indices = input_ids[:q_len]
doc_start_indices = range(q_len, len(input_ids), max_length - doc_stride)
for i in doc_start_indices:
end_index = i + max_length - q_len
slice = input_ids[i:end_index]
inputs.append(q_indices + slice)
category.append(answer["category"][0])
if slice[-1] == tokenizer.sep_token_id:
break
return {
"example_id": example["id"],
"input_ids": inputs,
"labels": {
"start_token": [-100] * len(category),
"end_token": [-100] * len(category),
"category": category,
},
}
splitted_context = out["context"].split()
complete_end_token = splitted_context[answer["end_token"]]
answer["start_token"] = len(
tokenizer(
" ".join(splitted_context[: answer["start_token"]]),
add_special_tokens=False,
).input_ids
)
answer["end_token"] = len(
tokenizer(" ".join(splitted_context[: answer["end_token"]]), add_special_tokens=False).input_ids
)
answer["start_token"] += q_len
answer["end_token"] += q_len
# fixing end token
num_sub_tokens = len(tokenizer(complete_end_token, add_special_tokens=False).input_ids)
if num_sub_tokens > 1:
answer["end_token"] += num_sub_tokens - 1
old = input_ids[answer["start_token"] : answer["end_token"] + 1] # right & left are inclusive
start_token = answer["start_token"]
end_token = answer["end_token"]
if assertion:
"""This won't match exactly because of extra gaps => visaully inspect everything"""
new = tokenizer.decode(old)
if answer["span"] != new:
print("ISSUE IN TOKENIZATION")
print("OLD:", answer["span"])
print("NEW:", new, end="\n\n")
if len(input_ids) <= max_length:
return {
"example_id": example["id"],
"input_ids": [input_ids],
"labels": {
"start_token": [answer["start_token"]],
"end_token": [answer["end_token"]],
"category": answer["category"],
},
}
q_indices = input_ids[:q_len]
doc_start_indices = range(q_len, len(input_ids), max_length - doc_stride)
inputs = []
answers_start_token = []
answers_end_token = []
answers_category = [] # null, yes, no, long, short
for i in doc_start_indices:
end_index = i + max_length - q_len
slice = input_ids[i:end_index]
inputs.append(q_indices + slice)
assert len(inputs[-1]) <= max_length, "Issue in truncating length"
if start_token >= i and end_token <= end_index - 1:
start_token = start_token - i + q_len
end_token = end_token - i + q_len
answers_category.append(answer["category"][0]) # ["short"] -> "short"
else:
start_token = -100
end_token = -100
answers_category.append("null")
new = inputs[-1][start_token : end_token + 1]
answers_start_token.append(start_token)
answers_end_token.append(end_token)
if assertion:
"""checking if above code is working as expected for all the samples"""
if new != old and new != [tokenizer.cls_token_id]:
print("ISSUE in strided for ID:", example["id"])
print("New:", tokenizer.decode(new))
print("Old:", tokenizer.decode(old), end="\n\n")
if slice[-1] == tokenizer.sep_token_id:
break
return {
"example_id": example["id"],
"input_ids": inputs,
"labels": {
"start_token": answers_start_token,
"end_token": answers_end_token,
"category": answers_category,
},
}
def prepare_inputs(example, tokenizer, doc_stride=2048, max_length=4096, assertion=False):
example = get_strided_contexts_and_ans(
example,
tokenizer,
doc_stride=doc_stride,
max_length=max_length,
assertion=assertion,
)
return example
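# Write one JSON line per window; windows with no answer at all are skipped and most "null"-category windows are randomly dropped to rebalance the data.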
def save_to_disk(hf_data, file_name):
with jsonlines.open(file_name, "a") as writer:
for example in tqdm(hf_data, total=len(hf_data), desc="Saving samples ... "):
labels = example["labels"]
for ids, start, end, cat in zip(
example["input_ids"],
labels["start_token"],
labels["end_token"],
labels["category"],
):
if start == -1 and end == -1:
continue  # skip windows with no answer
if cat == "null" and np.random.rand() < 0.6:
continue  # randomly drop ~60% of the null-category windows
writer.write(
{
"input_ids": ids,
"start_token": start,
"end_token": end,
"category": CATEGORY_MAPPING[cat],
}
)
if __name__ == "__main__":
"""Running area"""
from datasets import load_dataset
from transformers import BigBirdTokenizer
data = load_dataset("natural_questions")
tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base")
data = data["train" if PROCESS_TRAIN == "true" else "validation"]
fn_kwargs = dict(
tokenizer=tokenizer,
doc_stride=DOC_STRIDE,
max_length=MAX_LENGTH,
assertion=False,
)
data = data.map(prepare_inputs, fn_kwargs=fn_kwargs)
data = data.remove_columns(["annotations", "document", "id", "question"])
print(data)
np.random.seed(SEED)
cache_file_name = "nq-training.jsonl" if PROCESS_TRAIN == "true" else "nq-validation.jsonl"
save_to_disk(data, file_name=cache_file_name)

requirements.txt
@@ -0,0 +1,6 @@
git+https://github.com/huggingface/transformers@master
datasets
sentencepiece
wandb
flax
jsonlines

sweep_flax.yaml
@@ -0,0 +1,16 @@
command:
- python3
- train.py
method: random
parameters:
lr:
values: [4e-5, 3e-5]
warmup_steps:
values: [20000, 15000, 10000, 5000]
weight_decay:
distribution: normal
mu: 1e-2
sigma: 2e-3
metric:
name: eval_loss
goal: minimize

train.py
@@ -0,0 +1,78 @@
import os
from dataclasses import replace
from datasets import load_dataset
import jax
import wandb
from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step
from flax import jax_utils
from transformers import BigBirdTokenizerFast
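# End-to-end launcher: hyperparameters come from Args (optionally overridden by a wandb sweep), the datasets are the jsonl files
# produced by prepare_natural_questions.py, and training runs through the pmap-based Trainer defined in bigbird_flax.py.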
if __name__ == "__main__":
print("#################### AVAILABLE DEVICES ####################")
print(jax.devices())
print("###########################################################")
# setup for wandb sweep
args = Args()
logger = wandb.init(project="bigbird-natural-questions", config=args.__dict__)
wandb_args = dict(logger.config)
del wandb_args["batch_size"]
args = replace(args, **wandb_args)
base_dir = args.base_dir + "-" + wandb.run.id
args = replace(args, base_dir=base_dir)
print(args)
tr_dataset = load_dataset("json", data_files=args.tr_data_path)["train"]
val_dataset = load_dataset("json", data_files=args.val_data_path)["train"]
# drop extra batch for now
indices = range(len(tr_dataset) - len(tr_dataset) % args.batch_size)
tr_dataset = tr_dataset.shuffle().select(indices)
indices = range(len(val_dataset) - len(val_dataset) % args.batch_size)
val_dataset = val_dataset.shuffle().select(indices)
if os.environ.get("TRAIN_ON_SMALL", "false") == "true":
tr_dataset = tr_dataset.shuffle().select(range(80000))
val_dataset = val_dataset.shuffle().select(range(8000))
print(tr_dataset)
print(val_dataset)
model = FlaxBigBirdForNaturalQuestions.from_pretrained(
args.model_id, block_size=args.block_size, num_random_blocks=args.num_random_blocks
)
tokenizer = BigBirdTokenizerFast.from_pretrained(args.model_id)
data_collator = DataCollator(pad_id=tokenizer.pad_token_id, max_length=4096)
tx_args = {
"lr": args.lr,
"init_lr": args.init_lr,
"warmup_steps": args.warmup_steps,
"num_train_steps": args.max_epochs * (len(tr_dataset) // args.batch_size),
"weight_decay": args.weight_decay,
}
tx, lr = build_tx(**tx_args)
trainer = Trainer(
args=args,
data_collator=data_collator,
model_save_fn=model.save_pretrained,
train_step_fn=train_step,
val_step_fn=val_step,
logger=logger,
scheduler_fn=lr,
)
ckpt_dir = None
state = trainer.create_state(model, tx, num_train_steps=tx_args["num_train_steps"], ckpt_dir=ckpt_dir)
try:
trainer.train(state, tr_dataset, val_dataset)
except KeyboardInterrupt:
print("Oooops; TRAINING STOPPED UNFORTUNATELY")
print("SAVING WEIGHTS IN `final-weights`")
params = jax_utils.unreplicate(state.params)
model.save_pretrained(os.path.join(args.base_dir, "final-weights"), params=params)