diff --git a/examples/research_projects/jax-projects/big_bird/README.md b/examples/research_projects/jax-projects/big_bird/README.md
new file mode 100644
index 0000000000..36e2f52a79
--- /dev/null
+++ b/examples/research_projects/jax-projects/big_bird/README.md
@@ -0,0 +1,60 @@
+
+Author: [@vasudevgupta7](https://github.com/vasudevgupta7)
+
+## Intro
+
+In this project, we fine-tuned [**BigBird**](https://arxiv.org/abs/2007.14062) on the [**natural-questions**](https://huggingface.co/datasets/natural_questions) dataset for the **question-answering** task on long documents. **BigBird** is a **sparse-attention-based transformer** which extends Transformer-based models, such as BERT, to much **longer sequences**.
+
+Read more about BigBird at https://huggingface.co/blog/big-bird
+
+## Fine-tuning
+
+**Setup**
+
+You need to install JAX yourself by following the official installation instructions ([see here](https://github.com/google/jax#installation)). The remaining requirements for this project can be installed by running the following command:
+
+```shell
+pip3 install -qr requirements.txt
+```
+
+**Download & prepare dataset**
+
+The Natural Questions corpus contains questions from real users, and it requires QA systems to read and comprehend an entire Wikipedia article that may or may not contain the answer to the question. The corpus takes ~100 GB on disk. We use HuggingFace `datasets` to download & process it.
+
+```shell
+# simply run the following command
+python3 prepare_natural_questions.py
+
+# this will download the whole dataset from the HuggingFace Hub & make it ready for training
+# this script takes ~3 hours to process the dataset
+```
+
+**Launch Training**
+
+We trained on a Cloud TPU v3-8. Each epoch took around 4.5 hours and the model converged in just 2 epochs. You can see the complete training arguments in [this script](bigbird_flax.py).
+
+```shell
+# simply run the following command
+python3 train.py
+
+# in case you want to try hyperparameter tuning, you can run a wandb sweep
+wandb sweep --project=bigbird sweep_flax.yaml
+wandb agent <sweep-id>
+```
+
+## Evaluation
+
+Our evaluation script differs from the original one, and for simplicity we evaluate only sequences of length up to 4096. Using our evaluation script, we obtained an **EM score of ~55.2**.
+
+```shell
+# download the validation dataset first
+mkdir natural-questions-validation
+wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/natural_questions-validation.arrow -P natural-questions-validation
+wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/dataset_info.json -P natural-questions-validation
+wget https://huggingface.co/datasets/vasudevgupta/natural-questions-validation/resolve/main/state.json -P natural-questions-validation
+
+# then simply run the following command
+python3 evaluate.py
+```
+
+You can find our checkpoint on the HuggingFace Hub ([see this](https://huggingface.co/vasudevgupta/flax-bigbird-natural-questions)). In case you are interested in fine-tuning BigBird in PyTorch, you can refer to [this repository](https://github.com/vasudevgupta7/bigbird).
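+
+**Quick inference check**
+
+As a quick sanity check of the released checkpoint, something like the following should work. This is only a minimal sketch mirroring what `evaluate.py` does (it assumes you run it from this directory so that `bigbird_flax` is importable; the question/context pair is a toy example, real NQ contexts are much longer):
+
+```python
+import jax.numpy as jnp
+
+from bigbird_flax import FlaxBigBirdForNaturalQuestions
+from transformers import BigBirdTokenizerFast
+
+model_id = "vasudevgupta/flax-bigbird-natural-questions"
+model = FlaxBigBirdForNaturalQuestions.from_pretrained(model_id)
+tokenizer = BigBirdTokenizerFast.from_pretrained(model_id)
+
+# toy inputs just to exercise the model
+question = "Who designed the Eiffel Tower?"
+context = "The Eiffel Tower was designed by the engineer Gustave Eiffel."
+
+inputs = tokenizer(question, context, return_tensors="jax")
+# the model returns (start_logits, end_logits, category_logits), see bigbird_flax.py
+start_logits, end_logits, _ = model(**inputs)
+
+# greedy span decoding; evaluate.py uses a smarter top-k search instead
+start = jnp.argmax(start_logits[0]).item()
+end = jnp.argmax(end_logits[0]).item()
+print(tokenizer.decode(inputs["input_ids"][0][start : end + 1]))
+```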
diff --git a/examples/research_projects/jax-projects/big_bird/bigbird_flax.py b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py
new file mode 100644
index 0000000000..d272125472
--- /dev/null
+++ b/examples/research_projects/jax-projects/big_bird/bigbird_flax.py
@@ -0,0 +1,321 @@
+import json
+import os
+from dataclasses import dataclass
+from functools import partial
+from typing import Callable
+
+from tqdm.auto import tqdm
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import joblib
+import optax
+import wandb
+from flax import jax_utils, struct, traverse_util
+from flax.serialization import from_bytes, to_bytes
+from flax.training import train_state
+from flax.training.common_utils import shard
+from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
+from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule
+
+
+class FlaxBigBirdForNaturalQuestionsModule(FlaxBigBirdForQuestionAnsweringModule):
+    """
+    BigBirdForQuestionAnswering with a CLS head on top for predicting the answer category.
+
+    This way we can load its weights with FlaxBigBirdForQuestionAnswering.
+    """
+
+    config: BigBirdConfig
+    dtype: jnp.dtype = jnp.float32
+    add_pooling_layer: bool = True
+
+    def setup(self):
+        super().setup()
+        self.cls = nn.Dense(5, dtype=self.dtype)
+
+    def __call__(self, *args, **kwargs):
+        outputs = super().__call__(*args, **kwargs)
+        cls_out = self.cls(outputs[2])
+        return outputs[:2] + (cls_out,)
+
+
+class FlaxBigBirdForNaturalQuestions(FlaxBigBirdForQuestionAnswering):
+    module_class = FlaxBigBirdForNaturalQuestionsModule
+
+
+def calculate_loss_for_nq(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels):
+    def cross_entropy(logits, labels, reduction=None):
+        """
+        Args:
+            logits: bsz, seqlen, vocab_size
+            labels: bsz, seqlen
+        """
+        vocab_size = logits.shape[-1]
+        labels = (labels[..., None] == jnp.arange(vocab_size)[None]).astype("f4")
+        logits = jax.nn.log_softmax(logits, axis=-1)
+        loss = -jnp.sum(labels * logits, axis=-1)
+        if reduction is not None:
+            loss = reduction(loss)
+        return loss
+
+    cross_entropy = partial(cross_entropy, reduction=jnp.mean)
+    start_loss = cross_entropy(start_logits, start_labels)
+    end_loss = cross_entropy(end_logits, end_labels)
+    pooled_loss = cross_entropy(pooled_logits, pooler_labels)
+    return (start_loss + end_loss + pooled_loss) / 3
+
+
+@dataclass
+class Args:
+    model_id: str = "google/bigbird-roberta-base"
+    logging_steps: int = 3000
+    save_steps: int = 10500
+
+    block_size: int = 128
+    num_random_blocks: int = 3
+
+    batch_size_per_device: int = 1
+    max_epochs: int = 5
+
+    # tx_args
+    lr: float = 3e-5
+    init_lr: float = 0.0
+    warmup_steps: int = 20000
+    weight_decay: float = 0.0095
+
+    save_dir: str = "bigbird-roberta-natural-questions"
+    base_dir: str = "training-expt"
+    tr_data_path: str = "data/nq-training.jsonl"
+    val_data_path: str = "data/nq-validation.jsonl"
+
+    def __post_init__(self):
+        os.makedirs(self.base_dir, exist_ok=True)
+        self.save_dir = os.path.join(self.base_dir, self.save_dir)
+        self.batch_size = self.batch_size_per_device * jax.device_count()
+
+
+@dataclass
+class DataCollator:
+
+    pad_id: int
+    max_length: int = 4096  # no dynamic padding on TPUs
+
+    def __call__(self, batch):
+        batch = self.collate_fn(batch)
+        batch = jax.tree_map(shard, batch)
+        return batch
+
+    def collate_fn(self, features):
+        input_ids, attention_mask = self.fetch_inputs(features["input_ids"])
+        batch = {
+            "input_ids": jnp.array(input_ids, dtype=jnp.int32),
+            "attention_mask": jnp.array(attention_mask, dtype=jnp.int32),
+            "start_labels": jnp.array(features["start_token"], dtype=jnp.int32),
+            "end_labels": jnp.array(features["end_token"], dtype=jnp.int32),
+            "pooled_labels": jnp.array(features["category"], dtype=jnp.int32),
+        }
+        return batch
+
+    def fetch_inputs(self, input_ids: list):
+        inputs = [self._fetch_inputs(ids) for ids in input_ids]
+        return zip(*inputs)
+
+    def _fetch_inputs(self, input_ids: list):
+        attention_mask = [1 for _ in range(len(input_ids))]
+        while len(input_ids) < self.max_length:
+            input_ids.append(self.pad_id)
+            attention_mask.append(0)
+        return input_ids, attention_mask
+
+
+def get_batched_dataset(dataset, batch_size, seed=None):
+    if seed is not None:
+        dataset = dataset.shuffle(seed=seed)
+    for i in range(len(dataset) // batch_size):
+        batch = dataset[i * batch_size : (i + 1) * batch_size]
+        yield dict(batch)
+
+
+@partial(jax.pmap, axis_name="batch")
+def train_step(state, drp_rng, **model_inputs):
+    def loss_fn(params):
+        start_labels = model_inputs.pop("start_labels")
+        end_labels = model_inputs.pop("end_labels")
+        pooled_labels = model_inputs.pop("pooled_labels")
+
+        outputs = state.apply_fn(**model_inputs, params=params, dropout_rng=drp_rng, train=True)
+        start_logits, end_logits, pooled_logits = outputs
+
+        return state.loss_fn(
+            start_logits,
+            start_labels,
+            end_logits,
+            end_labels,
+            pooled_logits,
+            pooled_labels,
+        )
+
+    drp_rng, new_drp_rng = jax.random.split(drp_rng)
+    grad_fn = jax.value_and_grad(loss_fn)
+    loss, grads = grad_fn(state.params)
+    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
+    grads = jax.lax.pmean(grads, "batch")
+
+    state = state.apply_gradients(grads=grads)
+    return state, metrics, new_drp_rng
+
+
+@partial(jax.pmap, axis_name="batch")
+def val_step(state, **model_inputs):
+    start_labels = model_inputs.pop("start_labels")
+    end_labels = model_inputs.pop("end_labels")
+    pooled_labels = model_inputs.pop("pooled_labels")
+
+    outputs = state.apply_fn(**model_inputs, params=state.params, train=False)
+    start_logits, end_logits, pooled_logits = outputs
+
+    loss = state.loss_fn(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels)
+    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
+    return metrics
+
+
+class TrainState(train_state.TrainState):
+    loss_fn: Callable = struct.field(pytree_node=False)
+
+
+@dataclass
+class Trainer:
+    args: Args
+    data_collator: Callable
+    train_step_fn: Callable
+    val_step_fn: Callable
+    model_save_fn: Callable
+    logger: wandb
+    scheduler_fn: Callable = None
+
+    def create_state(self, model, tx, num_train_steps, ckpt_dir=None):
+        params = model.params
+        state = TrainState.create(
+            apply_fn=model.__call__,
+            params=params,
+            tx=tx,
+            loss_fn=calculate_loss_for_nq,
+        )
+        if ckpt_dir is not None:
+            params, opt_state, step, args, data_collator = restore_checkpoint(ckpt_dir, state)
+            tx_args = {
+                "lr": args.lr,
+                "init_lr": args.init_lr,
+                "warmup_steps": args.warmup_steps,
+                "num_train_steps": num_train_steps,
+                "weight_decay": args.weight_decay,
+            }
+            tx, lr = build_tx(**tx_args)
+            # rebuild the custom TrainState (not the base flax one); otherwise
+            # `state.loss_fn` would be missing after restoring from a checkpoint
+            state = TrainState(
+                step=step,
+                apply_fn=model.__call__,
+                params=params,
+                tx=tx,
+                opt_state=opt_state,
+                loss_fn=calculate_loss_for_nq,
+            )
+            self.args = args
+            self.data_collator = data_collator
+            self.scheduler_fn = lr
+            model.params = params
+        state = jax_utils.replicate(state)
+        return state
+
+    def train(self, state, tr_dataset, val_dataset):
+        args = self.args
+        total = len(tr_dataset) // args.batch_size
+
+        rng = jax.random.PRNGKey(0)
+        drp_rng = jax.random.split(rng, jax.device_count())
+        for epoch in range(args.max_epochs):
+            running_loss = jnp.array(0, dtype=jnp.float32)
+            tr_dataloader = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
+            i = 0
+            for batch in tqdm(tr_dataloader, total=total, desc=f"Running EPOCH-{epoch}"):
+                batch = self.data_collator(batch)
+                state, metrics, drp_rng = self.train_step_fn(state, drp_rng, **batch)
+                running_loss += jax_utils.unreplicate(metrics["loss"])
+                i += 1
+                if i % args.logging_steps == 0:
+                    state_step = jax_utils.unreplicate(state.step)
+                    tr_loss = running_loss.item() / i
+                    lr = self.scheduler_fn(state_step - 1)
+
+                    eval_loss = self.evaluate(state, val_dataset)
+                    logging_dict = dict(
+                        step=state_step.item(), eval_loss=eval_loss.item(), tr_loss=tr_loss, lr=lr.item()
+                    )
+                    tqdm.write(str(logging_dict))
+                    self.logger.log(logging_dict, commit=True)
+
+                if i % args.save_steps == 0:
+                    self.save_checkpoint(args.save_dir + f"-e{epoch}-s{i}", state=state)
+
+    def evaluate(self, state, dataset):
+        dataloader = get_batched_dataset(dataset, self.args.batch_size)
+        total = len(dataset) // self.args.batch_size
+        running_loss = jnp.array(0, dtype=jnp.float32)
+        i = 0
+        for batch in tqdm(dataloader, total=total, desc="Evaluating ... "):
+            batch = self.data_collator(batch)
+            metrics = self.val_step_fn(state, **batch)
+            running_loss += jax_utils.unreplicate(metrics["loss"])
+            i += 1
+        return running_loss / i
+
+    def save_checkpoint(self, save_dir, state):
+        state = jax_utils.unreplicate(state)
+        print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
+        self.model_save_fn(save_dir, params=state.params)
+        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
+            f.write(to_bytes(state.opt_state))
+        joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
+        joblib.dump(self.data_collator, os.path.join(save_dir, "data_collator.joblib"))
+        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
+            json.dump({"step": state.step.item()}, f)
+        print("DONE")
+
+
+def restore_checkpoint(save_dir, state):
+    print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
+    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
+        params = from_bytes(state.params, f.read())
+
+    with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
+        opt_state = from_bytes(state.opt_state, f.read())
+
+    args = joblib.load(os.path.join(save_dir, "args.joblib"))
+    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))
+
+    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
+        training_state = json.load(f)
+    step = training_state["step"]
+
+    print("DONE")
+    return params, opt_state, step, args, data_collator
+
+
+def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
+    decay_steps = num_train_steps - warmup_steps
+    warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
+    decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps)
+    lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
+    return lr
+
+
+def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
+    def weight_decay_mask(params):
+        params = traverse_util.flatten_dict(params)
+        mask = {k: (v[-1] != "bias" and v[-2:] != ("LayerNorm", "scale")) for k, v in params.items()}
+        return traverse_util.unflatten_dict(mask)
+
+    lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)
+
+    tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
+    return tx, lr
diff --git a/examples/research_projects/jax-projects/big_bird/evaluate.py b/examples/research_projects/jax-projects/big_bird/evaluate.py
new file mode 100644
index 0000000000..d81db40a95
--- /dev/null
+++ b/examples/research_projects/jax-projects/big_bird/evaluate.py
@@ -0,0 +1,165 @@
+from datasets import load_from_disk
+
+import jax
+import jax.numpy as jnp
+from bigbird_flax import FlaxBigBirdForNaturalQuestions
+from transformers import BigBirdTokenizerFast
+
+
+CATEGORY_MAPPING = {0: "null", 1: "short", 2: "long", 3: "yes", 4: "no"}
+PUNCTUATION_SET_TO_EXCLUDE = set("".join(["‘", "’", "´", "`", ".", ",", "-", '"']))
+
+
+def get_sub_answers(answers, begin=0, end=None):
+    return [" ".join(x.split(" ")[begin:end]) for x in answers if len(x.split(" ")) > 1]
+
+
+def expand_to_aliases(given_answers, make_sub_answers=False):
+    if make_sub_answers:
+        # if answers are longer than one word, make sure a prediction is also counted as correct if it
+        # corresponds to the complete `1:` or `:-1` sub-answer,
+        # *e.g.* if the correct answer contains a prefix such as "the", or "a"
+        given_answers = (
+            given_answers + get_sub_answers(given_answers, begin=1) + get_sub_answers(given_answers, end=-1)
+        )
+    answers = []
+    for answer in given_answers:
+        alias = answer.replace("_", " ").lower()
+        alias = "".join(c if c not in PUNCTUATION_SET_TO_EXCLUDE else " " for c in alias)
+        answers.append(" ".join(alias.split()).strip())
+    return set(answers)
+
+
+def get_best_valid_start_end_idx(start_scores, end_scores, top_k=1, max_size=100):
+    best_start_scores, best_start_idx = jax.lax.top_k(start_scores, top_k)
+    best_end_scores, best_end_idx = jax.lax.top_k(end_scores, top_k)
+
+    widths = best_end_idx[:, None] - best_start_idx[None, :]
+    mask = jnp.logical_or(widths < 0, widths > max_size)
+    scores = (best_end_scores[:, None] + best_start_scores[None, :]) - (1e8 * mask)
+    best_score = jnp.argmax(scores).item()
+
+    return best_start_idx[best_score % top_k], best_end_idx[best_score // top_k]
+
+
+def format_dataset(sample):
+    question = sample["question"]["text"]
+    context = sample["document"]["tokens"]["token"]
sample["document"]["tokens"]["token"] + is_html = sample["document"]["tokens"]["is_html"] + long_answers = sample["annotations"]["long_answer"] + short_answers = sample["annotations"]["short_answers"] + + context_string = " ".join([context[i] for i in range(len(context)) if not is_html[i]]) + + # 0 - No ; 1 - Yes + for answer in sample["annotations"]["yes_no_answer"]: + if answer == 0 or answer == 1: + return { + "question": question, + "context": context_string, + "short": [], + "long": [], + "category": "no" if answer == 0 else "yes", + } + + short_targets = [] + for s in short_answers: + short_targets.extend(s["text"]) + short_targets = list(set(short_targets)) + + long_targets = [] + for s in long_answers: + if s["start_token"] == -1: + continue + answer = context[s["start_token"] : s["end_token"]] + html = is_html[s["start_token"] : s["end_token"]] + new_answer = " ".join([answer[i] for i in range(len(answer)) if not html[i]]) + if new_answer not in long_targets: + long_targets.append(new_answer) + + category = "long_short" if len(short_targets + long_targets) > 0 else "null" + + return { + "question": question, + "context": context_string, + "short": short_targets, + "long": long_targets, + "category": category, + } + + +def main(): + dataset = load_from_disk("natural-questions-validation") + dataset = dataset.map(format_dataset).remove_columns(["annotations", "document", "id"]) + print(dataset) + + short_validation_dataset = dataset.filter(lambda x: (len(x["question"]) + len(x["context"])) < 4 * 4096) + short_validation_dataset = short_validation_dataset.filter(lambda x: x["category"] != "null") + short_validation_dataset + + model_id = "vasudevgupta/flax-bigbird-natural-questions" + model = FlaxBigBirdForNaturalQuestions.from_pretrained(model_id) + tokenizer = BigBirdTokenizerFast.from_pretrained(model_id) + + @jax.jit + def forward(*args, **kwargs): + start_logits, end_logits, pooled_logits = model(*args, **kwargs) + return start_logits, end_logits, jnp.argmax(pooled_logits, axis=-1) + + def evaluate(example): + # encode question and context so that they are seperated by a tokenizer.sep_token and cut at max_length + inputs = tokenizer( + example["question"], + example["context"], + return_tensors="jax", + max_length=4096, + padding="max_length", + truncation=True, + ) + + start_scores, end_scores, category = forward(**inputs) + + predicted_category = CATEGORY_MAPPING[category.item()] + + example["targets"] = example["long"] + example["short"] + if example["category"] in ["yes", "no", "null"]: + example["targets"] = [example["category"]] + example["has_tgt"] = example["category"] != "null" + # Now target can be: "yes", "no", "null", "list of long & short answers" + + if predicted_category in ["yes", "no", "null"]: + example["output"] = [predicted_category] + example["match"] = example["output"] == example["targets"] + example["has_pred"] = predicted_category != "null" + return example + + max_size = 38 if predicted_category == "short" else 1024 + start_score, end_score = get_best_valid_start_end_idx( + start_scores[0], end_scores[0], top_k=8, max_size=max_size + ) + + input_ids = inputs["input_ids"][0].tolist() + example["output"] = [tokenizer.decode(input_ids[start_score : end_score + 1])] + + answers = expand_to_aliases(example["targets"], make_sub_answers=True) + predictions = expand_to_aliases(example["output"]) + + # some preprocessing to both prediction and answer + answers = set(["".join(a.split()) for a in answers]) + predictions = set(["".join(p.split()) for p in 
predictions]) + predictions = set([s for s in predictions if s not in ["``", "''", "`", "'"]]) + + # if there is a common element, it's a exact match + example["match"] = len(list(answers & predictions)) > 0 + example["has_pred"] = predicted_category != "null" and len(predictions) > 0 + + return example + + short_validation_dataset = short_validation_dataset.map(evaluate) + + total = len(short_validation_dataset) + matched = len(short_validation_dataset.filter(lambda x: x["match"] == 1)) + print("EM score:", (matched / total) * 100, "%") + + +if __name__ == "__main__": + main() diff --git a/examples/research_projects/jax-projects/big_bird/prepare_natural_questions.py b/examples/research_projects/jax-projects/big_bird/prepare_natural_questions.py new file mode 100644 index 0000000000..8d2f69031e --- /dev/null +++ b/examples/research_projects/jax-projects/big_bird/prepare_natural_questions.py @@ -0,0 +1,330 @@ +import os + +import numpy as np +from tqdm import tqdm + +import jsonlines + + +DOC_STRIDE = 2048 +MAX_LENGTH = 4096 +SEED = 42 +PROCESS_TRAIN = os.environ.pop("PROCESS_TRAIN", "false") +CATEGORY_MAPPING = {"null": 0, "short": 1, "long": 2, "yes": 3, "no": 4} + + +def _get_single_answer(example): + def choose_first(answer, is_long_answer=False): + assert isinstance(answer, list) + if len(answer) == 1: + answer = answer[0] + return {k: [answer[k]] for k in answer} if is_long_answer else answer + for a in answer: + if is_long_answer: + a = {k: [a[k]] for k in a} + if len(a["start_token"]) > 0: + break + return a + + answer = {"id": example["id"]} + annotation = example["annotations"] + yes_no_answer = annotation["yes_no_answer"] + if 0 in yes_no_answer or 1 in yes_no_answer: + answer["category"] = ["yes"] if 1 in yes_no_answer else ["no"] + answer["start_token"] = answer["end_token"] = [] + answer["start_byte"] = answer["end_byte"] = [] + answer["text"] = [""] + else: + answer["category"] = ["short"] + out = choose_first(annotation["short_answers"]) + if len(out["start_token"]) == 0: + # answer will be long if short is not available + answer["category"] = ["long"] + out = choose_first(annotation["long_answer"], is_long_answer=True) + out["text"] = [] + answer.update(out) + + # disregard some samples + if len(answer["start_token"]) > 1 or answer["start_token"] == answer["end_token"]: + answer["remove_it"] = True + else: + answer["remove_it"] = False + + cols = ["start_token", "end_token", "start_byte", "end_byte", "text"] + if not all([isinstance(answer[k], list) for k in cols]): + raise ValueError("Issue in ID", example["id"]) + + return answer + + +def get_context_and_ans(example, assertion=False): + """Gives new context after removing & new answer tokens as per new context""" + answer = _get_single_answer(example) + # bytes are of no use + del answer["start_byte"] + del answer["end_byte"] + + # handle yes_no answers explicitly + if answer["category"][0] in ["yes", "no"]: # category is list with one element + doc = example["document"]["tokens"] + context = [] + for i in range(len(doc["token"])): + if not doc["is_html"][i]: + context.append(doc["token"][i]) + return { + "context": " ".join(context), + "answer": { + "start_token": -100, # ignore index in cross-entropy + "end_token": -100, # ignore index in cross-entropy + "category": answer["category"], + "span": answer["category"], # extra + }, + } + + # later, help in removing all no answers + if answer["start_token"] == [-1]: + return { + "context": "None", + "answer": { + "start_token": -1, + "end_token": -1, + "category": "null", + 
"span": "None", # extra + }, + } + + # handling normal samples + + cols = ["start_token", "end_token"] + answer.update({k: answer[k][0] if len(answer[k]) > 0 else answer[k] for k in cols}) # e.g. [10] == 10 + + doc = example["document"]["tokens"] + start_token = answer["start_token"] + end_token = answer["end_token"] + + context = [] + for i in range(len(doc["token"])): + if not doc["is_html"][i]: + context.append(doc["token"][i]) + else: + if answer["start_token"] > i: + start_token -= 1 + if answer["end_token"] > i: + end_token -= 1 + new = " ".join(context[start_token:end_token]) + + # checking above code + if assertion: + """checking if above code is working as expected for all the samples""" + is_html = doc["is_html"][answer["start_token"] : answer["end_token"]] + old = doc["token"][answer["start_token"] : answer["end_token"]] + old = " ".join([old[i] for i in range(len(old)) if not is_html[i]]) + if new != old: + print("ID:", example["id"]) + print("New:", new, end="\n") + print("Old:", old, end="\n\n") + + return { + "context": " ".join(context), + "answer": { + "start_token": start_token, + "end_token": end_token - 1, # this makes it inclusive + "category": answer["category"], # either long or short + "span": new, # extra + }, + } + + +def get_strided_contexts_and_ans(example, tokenizer, doc_stride=2048, max_length=4096, assertion=True): + # overlap will be of doc_stride - q_len + + out = get_context_and_ans(example, assertion=assertion) + answer = out["answer"] + + # later, removing these samples + if answer["start_token"] == -1: + return { + "example_id": example["id"], + "input_ids": [[-1]], + "labels": { + "start_token": [-1], + "end_token": [-1], + "category": ["null"], + }, + } + + input_ids = tokenizer(example["question"]["text"], out["context"]).input_ids + q_len = input_ids.index(tokenizer.sep_token_id) + 1 + + # return yes/no + if answer["category"][0] in ["yes", "no"]: # category is list with one element + inputs = [] + category = [] + q_indices = input_ids[:q_len] + doc_start_indices = range(q_len, len(input_ids), max_length - doc_stride) + for i in doc_start_indices: + end_index = i + max_length - q_len + slice = input_ids[i:end_index] + inputs.append(q_indices + slice) + category.append(answer["category"][0]) + if slice[-1] == tokenizer.sep_token_id: + break + + return { + "example_id": example["id"], + "input_ids": inputs, + "labels": { + "start_token": [-100] * len(category), + "end_token": [-100] * len(category), + "category": category, + }, + } + + splitted_context = out["context"].split() + complete_end_token = splitted_context[answer["end_token"]] + answer["start_token"] = len( + tokenizer( + " ".join(splitted_context[: answer["start_token"]]), + add_special_tokens=False, + ).input_ids + ) + answer["end_token"] = len( + tokenizer(" ".join(splitted_context[: answer["end_token"]]), add_special_tokens=False).input_ids + ) + + answer["start_token"] += q_len + answer["end_token"] += q_len + + # fixing end token + num_sub_tokens = len(tokenizer(complete_end_token, add_special_tokens=False).input_ids) + if num_sub_tokens > 1: + answer["end_token"] += num_sub_tokens - 1 + + old = input_ids[answer["start_token"] : answer["end_token"] + 1] # right & left are inclusive + start_token = answer["start_token"] + end_token = answer["end_token"] + + if assertion: + """This won't match exactly because of extra gaps => visaully inspect everything""" + new = tokenizer.decode(old) + if answer["span"] != new: + print("ISSUE IN TOKENIZATION") + print("OLD:", answer["span"]) + 
print("NEW:", new, end="\n\n") + + if len(input_ids) <= max_length: + return { + "example_id": example["id"], + "input_ids": [input_ids], + "labels": { + "start_token": [answer["start_token"]], + "end_token": [answer["end_token"]], + "category": answer["category"], + }, + } + + q_indices = input_ids[:q_len] + doc_start_indices = range(q_len, len(input_ids), max_length - doc_stride) + + inputs = [] + answers_start_token = [] + answers_end_token = [] + answers_category = [] # null, yes, no, long, short + for i in doc_start_indices: + end_index = i + max_length - q_len + slice = input_ids[i:end_index] + inputs.append(q_indices + slice) + assert len(inputs[-1]) <= max_length, "Issue in truncating length" + + if start_token >= i and end_token <= end_index - 1: + start_token = start_token - i + q_len + end_token = end_token - i + q_len + answers_category.append(answer["category"][0]) # ["short"] -> "short" + else: + start_token = -100 + end_token = -100 + answers_category.append("null") + new = inputs[-1][start_token : end_token + 1] + + answers_start_token.append(start_token) + answers_end_token.append(end_token) + if assertion: + """checking if above code is working as expected for all the samples""" + if new != old and new != [tokenizer.cls_token_id]: + print("ISSUE in strided for ID:", example["id"]) + print("New:", tokenizer.decode(new)) + print("Old:", tokenizer.decode(old), end="\n\n") + if slice[-1] == tokenizer.sep_token_id: + break + + return { + "example_id": example["id"], + "input_ids": inputs, + "labels": { + "start_token": answers_start_token, + "end_token": answers_end_token, + "category": answers_category, + }, + } + + +def prepare_inputs(example, tokenizer, doc_stride=2048, max_length=4096, assertion=False): + example = get_strided_contexts_and_ans( + example, + tokenizer, + doc_stride=doc_stride, + max_length=max_length, + assertion=assertion, + ) + + return example + + +def save_to_disk(hf_data, file_name): + with jsonlines.open(file_name, "a") as writer: + for example in tqdm(hf_data, total=len(hf_data), desc="Saving samples ... 
"): + labels = example["labels"] + for ids, start, end, cat in zip( + example["input_ids"], + labels["start_token"], + labels["end_token"], + labels["category"], + ): + if start == -1 and end == -1: + continue # leave waste samples with no answer + if cat == "null" and np.random.rand() < 0.6: + continue # removing 50 % samples + writer.write( + { + "input_ids": ids, + "start_token": start, + "end_token": end, + "category": CATEGORY_MAPPING[cat], + } + ) + + +if __name__ == "__main__": + """Running area""" + from datasets import load_dataset + + from transformers import BigBirdTokenizer + + data = load_dataset("natural_questions") + tokenizer = BigBirdTokenizer.from_pretrained("google/bigbird-roberta-base") + + data = data["train" if PROCESS_TRAIN == "true" else "validation"] + + fn_kwargs = dict( + tokenizer=tokenizer, + doc_stride=DOC_STRIDE, + max_length=MAX_LENGTH, + assertion=False, + ) + data = data.map(prepare_inputs, fn_kwargs=fn_kwargs) + data = data.remove_columns(["annotations", "document", "id", "question"]) + print(data) + + np.random.seed(SEED) + cache_file_name = "nq-training.jsonl" if PROCESS_TRAIN == "true" else "nq-validation.jsonl" + save_to_disk(data, file_name=cache_file_name) diff --git a/examples/research_projects/jax-projects/big_bird/requirements.txt b/examples/research_projects/jax-projects/big_bird/requirements.txt new file mode 100644 index 0000000000..4c9c2cb983 --- /dev/null +++ b/examples/research_projects/jax-projects/big_bird/requirements.txt @@ -0,0 +1,6 @@ +git+https://github.com/huggingface/transformers@master +datasets +sentencepiece +wandb +flax +jsonlines diff --git a/examples/research_projects/jax-projects/big_bird/sweep_flax.yaml b/examples/research_projects/jax-projects/big_bird/sweep_flax.yaml new file mode 100644 index 0000000000..d804f61b3e --- /dev/null +++ b/examples/research_projects/jax-projects/big_bird/sweep_flax.yaml @@ -0,0 +1,16 @@ +command: + - python3 + - train.py +method: random +parameters: + lr: + values: [4e-5, 3e-5] + warmup_steps: + values: [20000, 15000, 10000, 5000] + weight_decay: + distribution: normal + mu: 1e-2 + sigma: 2e-3 +metric: + name: eval_loss + goal: minimize diff --git a/examples/research_projects/jax-projects/big_bird/train.py b/examples/research_projects/jax-projects/big_bird/train.py new file mode 100644 index 0000000000..3d67c9d97f --- /dev/null +++ b/examples/research_projects/jax-projects/big_bird/train.py @@ -0,0 +1,78 @@ +import os +from dataclasses import replace + +from datasets import load_dataset + +import jax +import wandb +from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step +from flax import jax_utils +from transformers import BigBirdTokenizerFast + + +if __name__ == "__main__": + print("#################### AVAILABLE DEVICES ####################") + print(jax.devices()) + print("###########################################################") + + # setup for wandb sweep + args = Args() + logger = wandb.init(project="bigbird-natural-questions", config=args.__dict__) + wandb_args = dict(logger.config) + del wandb_args["batch_size"] + args = replace(args, **wandb_args) + base_dir = args.base_dir + "-" + wandb.run.id + args = replace(args, base_dir=base_dir) + print(args) + + tr_dataset = load_dataset("json", data_files=args.tr_data_path)["train"] + val_dataset = load_dataset("json", data_files=args.val_data_path)["train"] + + # drop extra batch for now + indices = range(len(tr_dataset) - len(tr_dataset) % args.batch_size) + tr_dataset = 
tr_dataset.shuffle().select(indices) + indices = range(len(val_dataset) - len(val_dataset) % args.batch_size) + val_dataset = val_dataset.shuffle().select(indices) + + if os.environ.get("TRAIN_ON_SMALL", "false") == "true": + tr_dataset = tr_dataset.shuffle().select(range(80000)) + val_dataset = val_dataset.shuffle().select(range(8000)) + + print(tr_dataset) + print(val_dataset) + + model = FlaxBigBirdForNaturalQuestions.from_pretrained( + args.model_id, block_size=args.block_size, num_random_blocks=args.num_random_blocks + ) + tokenizer = BigBirdTokenizerFast.from_pretrained(args.model_id) + data_collator = DataCollator(pad_id=tokenizer.pad_token_id, max_length=4096) + + tx_args = { + "lr": args.lr, + "init_lr": args.init_lr, + "warmup_steps": args.warmup_steps, + "num_train_steps": args.max_epochs * (len(tr_dataset) // args.batch_size), + "weight_decay": args.weight_decay, + } + tx, lr = build_tx(**tx_args) + + trainer = Trainer( + args=args, + data_collator=data_collator, + model_save_fn=model.save_pretrained, + train_step_fn=train_step, + val_step_fn=val_step, + logger=logger, + scheduler_fn=lr, + ) + + ckpt_dir = None + state = trainer.create_state(model, tx, num_train_steps=tx_args["num_train_steps"], ckpt_dir=ckpt_dir) + try: + trainer.train(state, tr_dataset, val_dataset) + except KeyboardInterrupt: + print("Oooops; TRAINING STOPPED UNFORTUNATELY") + + print("SAVING WEIGHTS IN `final-weights`") + params = jax_utils.unreplicate(state.params) + model.save_pretrained(os.path.join(args.base_dir, "final-weights"), params=params)
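+
+    # NOTE (sketch, not part of the original training run): to resume from a
+    # checkpoint written by `Trainer.save_checkpoint`, point `ckpt_dir` above at
+    # that directory before calling `trainer.create_state`, e.g. (hypothetical path):
+    #     ckpt_dir = "training-expt/bigbird-roberta-natural-questions-e0-s10500"
+    # `create_state` will then restore the params, optimizer state & step count.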