import json
import os
from dataclasses import dataclass
from functools import partial
from typing import Callable

import flax.linen as nn
import jax
import jax.numpy as jnp
import joblib
import optax
import wandb
from flax import jax_utils, struct, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from tqdm.auto import tqdm

from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule


class FlaxBigBirdForNaturalQuestionsModule(FlaxBigBirdForQuestionAnsweringModule):
    """
    BigBirdForQuestionAnswering with a CLS head on top for predicting the answer category.

    This way its weights can still be loaded with FlaxBigBirdForQuestionAnswering.
    """

    config: BigBirdConfig
    dtype: jnp.dtype = jnp.float32
    add_pooling_layer: bool = True

    def setup(self):
        super().setup()
        # Dense head over the pooled output, predicting one of 5 answer categories.
        self.cls = nn.Dense(5, dtype=self.dtype)

    def __call__(self, *args, **kwargs):
        outputs = super().__call__(*args, **kwargs)
        # outputs: (start_logits, end_logits, pooled_output, ...)
        cls_out = self.cls(outputs[2])
        return outputs[:2] + (cls_out,)


class FlaxBigBirdForNaturalQuestions(FlaxBigBirdForQuestionAnswering):
    module_class = FlaxBigBirdForNaturalQuestionsModule


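# Illustrative usage (a sketch, not executed here): the subclass can be loaded straight
# from the pretrained QA checkpoint; the extra `cls` head is then newly initialized.
# `block_size` / `num_random_blocks` mirror the fields of `Args` below.
#
#   model = FlaxBigBirdForNaturalQuestions.from_pretrained(
#       "google/bigbird-roberta-base", block_size=128, num_random_blocks=3
#   )
#   start_logits, end_logits, pooled_logits = model(**inputs)

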
def calculate_loss_for_nq(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooler_labels):
    def cross_entropy(logits, labels, reduction=None):
        """
        Args:
            logits: (..., num_classes), e.g. (bsz, seqlen) for start/end positions
            labels: (...) integer class indices, e.g. (bsz,)
        """
        num_classes = logits.shape[-1]
        # One-hot encode the labels by comparing them against the class indices.
        labels = (labels[..., None] == jnp.arange(num_classes)[None]).astype("f4")
        logits = jax.nn.log_softmax(logits, axis=-1)
        loss = -jnp.sum(labels * logits, axis=-1)
        if reduction is not None:
            loss = reduction(loss)
        return loss

    cross_entropy = partial(cross_entropy, reduction=jnp.mean)
    start_loss = cross_entropy(start_logits, start_labels)
    end_loss = cross_entropy(end_logits, end_labels)
    pooled_loss = cross_entropy(pooled_logits, pooler_labels)
    return (start_loss + end_loss + pooled_loss) / 3


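# Shape sanity check (illustrative): start/end logits of shape (bsz, seqlen) with
# integer labels of shape (bsz,), plus category logits of shape (bsz, 5) with labels
# of shape (bsz,), reduce to a single scalar: the mean of the three cross-entropy terms.
#
#   loss = calculate_loss_for_nq(
#       start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels
#   )

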
@dataclass
class Args:
    model_id: str = "google/bigbird-roberta-base"
    logging_steps: int = 3000
    save_steps: int = 10500

    block_size: int = 128
    num_random_blocks: int = 3

    batch_size_per_device: int = 1
    max_epochs: int = 5

    # tx_args
    lr: float = 3e-5
    init_lr: float = 0.0
    warmup_steps: int = 20000
    weight_decay: float = 0.0095

    save_dir: str = "bigbird-roberta-natural-questions"
    base_dir: str = "training-expt"
    tr_data_path: str = "data/nq-training.jsonl"
    val_data_path: str = "data/nq-validation.jsonl"

    def __post_init__(self):
        os.makedirs(self.base_dir, exist_ok=True)
        self.save_dir = os.path.join(self.base_dir, self.save_dir)
        self.batch_size = self.batch_size_per_device * jax.device_count()


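# Example (hypothetical device count): the effective batch size is derived from the
# number of accelerator devices visible to jax.
#
#   args = Args()
#   args.batch_size  # batch_size_per_device * jax.device_count(), e.g. 1 * 8 = 8 on a v3-8 TPU

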
@dataclass
class DataCollator:
    pad_id: int
    max_length: int = 4096  # no dynamic padding on TPUs

    def __call__(self, batch):
        batch = self.collate_fn(batch)
        batch = jax.tree_util.tree_map(shard, batch)
        return batch

    def collate_fn(self, features):
        input_ids, attention_mask = self.fetch_inputs(features["input_ids"])
        batch = {
            "input_ids": jnp.array(input_ids, dtype=jnp.int32),
            "attention_mask": jnp.array(attention_mask, dtype=jnp.int32),
            "start_labels": jnp.array(features["start_token"], dtype=jnp.int32),
            "end_labels": jnp.array(features["end_token"], dtype=jnp.int32),
            "pooled_labels": jnp.array(features["category"], dtype=jnp.int32),
        }
        return batch

    def fetch_inputs(self, input_ids: list):
        inputs = [self._fetch_inputs(ids) for ids in input_ids]
        return zip(*inputs)

    def _fetch_inputs(self, input_ids: list):
        attention_mask = [1 for _ in range(len(input_ids))]
        # Pad (in place) up to max_length; padded positions are masked out.
        while len(input_ids) < self.max_length:
            input_ids.append(self.pad_id)
            attention_mask.append(0)
        return input_ids, attention_mask


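# Illustrative usage (hypothetical values): the collator pads each example to
# `max_length` and shards the batch across devices, so the overall batch size must be
# divisible by jax.device_count().
#
#   collator = DataCollator(pad_id=0, max_length=4096)
#   batch = collator({
#       "input_ids": [[65, 464, 66], [65, 524, 66]],  # pre-tokenized, variable length
#       "start_token": [1, 1],
#       "end_token": [2, 2],
#       "category": [0, 3],
#   })
#   # batch["input_ids"].shape == (jax.device_count(), per_device_bsz, 4096)

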
def get_batched_dataset(dataset, batch_size, seed=None):
    if seed is not None:
        dataset = dataset.shuffle(seed=seed)
    # Note: trailing examples that don't fill a complete batch are dropped.
    for i in range(len(dataset) // batch_size):
        batch = dataset[i * batch_size : (i + 1) * batch_size]
        yield dict(batch)


@partial(jax.pmap, axis_name="batch")
def train_step(state, drp_rng, **model_inputs):
    def loss_fn(params):
        # Pop the labels so that only true model inputs are forwarded to apply_fn.
        start_labels = model_inputs.pop("start_labels")
        end_labels = model_inputs.pop("end_labels")
        pooled_labels = model_inputs.pop("pooled_labels")

        outputs = state.apply_fn(**model_inputs, params=params, dropout_rng=drp_rng, train=True)
        start_logits, end_logits, pooled_logits = outputs

        return state.loss_fn(
            start_logits,
            start_labels,
            end_logits,
            end_labels,
            pooled_logits,
            pooled_labels,
        )

    drp_rng, new_drp_rng = jax.random.split(drp_rng)
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    # Average the gradients across devices before applying them.
    grads = jax.lax.pmean(grads, "batch")

    state = state.apply_gradients(grads=grads)
    return state, metrics, new_drp_rng


@partial(jax.pmap, axis_name="batch")
def val_step(state, **model_inputs):
    start_labels = model_inputs.pop("start_labels")
    end_labels = model_inputs.pop("end_labels")
    pooled_labels = model_inputs.pop("pooled_labels")

    outputs = state.apply_fn(**model_inputs, params=state.params, train=False)
    start_logits, end_logits, pooled_logits = outputs

    loss = state.loss_fn(start_logits, start_labels, end_logits, end_labels, pooled_logits, pooled_labels)
    metrics = jax.lax.pmean({"loss": loss}, axis_name="batch")
    return metrics


class TrainState(train_state.TrainState):
    # Static field: the loss callable must not be treated as a pytree leaf by jax transforms.
    loss_fn: Callable = struct.field(pytree_node=False)


@dataclass
class Trainer:
    args: Args
    data_collator: Callable
    train_step_fn: Callable
    val_step_fn: Callable
    model_save_fn: Callable
    logger: wandb
    scheduler_fn: Callable = None

    def create_state(self, model, tx, num_train_steps, ckpt_dir=None):
        params = model.params
        state = TrainState.create(
            apply_fn=model.__call__,
            params=params,
            tx=tx,
            loss_fn=calculate_loss_for_nq,
        )
        if ckpt_dir is not None:
            params, opt_state, step, args, data_collator = restore_checkpoint(ckpt_dir, state)
            tx_args = {
                "lr": args.lr,
                "init_lr": args.init_lr,
                "warmup_steps": args.warmup_steps,
                "num_train_steps": num_train_steps,
                "weight_decay": args.weight_decay,
            }
            tx, lr = build_tx(**tx_args)
            # Rebuild the state from the restored step/params/optimizer state. This must
            # be the custom TrainState (carrying loss_fn), not the base class, otherwise
            # `state.loss_fn` fails inside train_step/val_step.
            state = TrainState(
                step=step,
                apply_fn=model.__call__,
                params=params,
                tx=tx,
                opt_state=opt_state,
                loss_fn=calculate_loss_for_nq,
            )
            self.args = args
            self.data_collator = data_collator
            self.scheduler_fn = lr
            model.params = params
        state = jax_utils.replicate(state)
        return state

    def train(self, state, tr_dataset, val_dataset):
        args = self.args
        total = len(tr_dataset) // args.batch_size

        rng = jax.random.PRNGKey(0)
        # One dropout rng per device for the pmapped train step.
        drp_rng = jax.random.split(rng, jax.device_count())
        for epoch in range(args.max_epochs):
            running_loss = jnp.array(0, dtype=jnp.float32)
            tr_dataloader = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
            i = 0
            for batch in tqdm(tr_dataloader, total=total, desc=f"Running EPOCH-{epoch}"):
                batch = self.data_collator(batch)
                state, metrics, drp_rng = self.train_step_fn(state, drp_rng, **batch)
                running_loss += jax_utils.unreplicate(metrics["loss"])
                i += 1
                if i % args.logging_steps == 0:
                    state_step = jax_utils.unreplicate(state.step)
                    tr_loss = running_loss.item() / i
                    lr = self.scheduler_fn(state_step - 1)

                    eval_loss = self.evaluate(state, val_dataset)
                    logging_dict = {
                        "step": state_step.item(),
                        "eval_loss": eval_loss.item(),
                        "tr_loss": tr_loss,
                        "lr": lr.item(),
                    }
                    tqdm.write(str(logging_dict))
                    self.logger.log(logging_dict, commit=True)

                if i % args.save_steps == 0:
                    self.save_checkpoint(args.save_dir + f"-e{epoch}-s{i}", state=state)

    def evaluate(self, state, dataset):
        dataloader = get_batched_dataset(dataset, self.args.batch_size)
        total = len(dataset) // self.args.batch_size
        running_loss = jnp.array(0, dtype=jnp.float32)
        i = 0
        for batch in tqdm(dataloader, total=total, desc="Evaluating ... "):
            batch = self.data_collator(batch)
            metrics = self.val_step_fn(state, **batch)
            running_loss += jax_utils.unreplicate(metrics["loss"])
            i += 1
        return running_loss / i

    def save_checkpoint(self, save_dir, state):
        state = jax_utils.unreplicate(state)
        print(f"SAVING CHECKPOINT IN {save_dir}", end=" ... ")
        self.model_save_fn(save_dir, params=state.params)
        with open(os.path.join(save_dir, "opt_state.msgpack"), "wb") as f:
            f.write(to_bytes(state.opt_state))
        joblib.dump(self.args, os.path.join(save_dir, "args.joblib"))
        joblib.dump(self.data_collator, os.path.join(save_dir, "data_collator.joblib"))
        with open(os.path.join(save_dir, "training_state.json"), "w") as f:
            json.dump({"step": state.step.item()}, f)
        print("DONE")


def restore_checkpoint(save_dir, state):
    print(f"RESTORING CHECKPOINT FROM {save_dir}", end=" ... ")
    with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
        params = from_bytes(state.params, f.read())

    with open(os.path.join(save_dir, "opt_state.msgpack"), "rb") as f:
        opt_state = from_bytes(state.opt_state, f.read())

    args = joblib.load(os.path.join(save_dir, "args.joblib"))
    data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))

    with open(os.path.join(save_dir, "training_state.json"), "r") as f:
        training_state = json.load(f)
    step = training_state["step"]

    print("DONE")
    return params, opt_state, step, args, data_collator


def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
    decay_steps = num_train_steps - warmup_steps
    warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
    decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps)
    lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
    return lr


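# For example (values are illustrative), the joined schedule warms up linearly and then
# decays linearly towards 1e-7:
#
#   lr_fn = scheduler_fn(lr=3e-5, init_lr=0.0, warmup_steps=100, num_train_steps=1000)
#   lr_fn(0)     # 0.0
#   lr_fn(100)   # 3e-5 (end of warmup)
#   lr_fn(1000)  # ~1e-7

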
def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
    def weight_decay_mask(params):
        params = traverse_util.flatten_dict(params)
        # Flattened keys are tuples of path components; exempt biases and LayerNorm scales.
        mask = {k: (k[-1] != "bias" and k[-2:] != ("LayerNorm", "scale")) for k in params}
        return traverse_util.unflatten_dict(mask)

    lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)

    tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
    return tx, lr
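

# Minimal end-to-end wiring sketch. This is an assumption about intended usage rather
# than the project's official entry point: it presumes the `datasets` library and
# jsonl files whose records carry the keys DataCollator expects ("input_ids",
# "start_token", "end_token", "category").
if __name__ == "__main__":
    from datasets import load_dataset

    args = Args()
    tr_dataset = load_dataset("json", data_files=args.tr_data_path)["train"]
    val_dataset = load_dataset("json", data_files=args.val_data_path)["train"]

    model = FlaxBigBirdForNaturalQuestions.from_pretrained(
        args.model_id, block_size=args.block_size, num_random_blocks=args.num_random_blocks
    )
    data_collator = DataCollator(pad_id=model.config.pad_token_id, max_length=4096)

    num_train_steps = (len(tr_dataset) // args.batch_size) * args.max_epochs
    tx, lr = build_tx(args.lr, args.init_lr, args.warmup_steps, num_train_steps, args.weight_decay)

    trainer = Trainer(
        args=args,
        data_collator=data_collator,
        train_step_fn=train_step,
        val_step_fn=val_step,
        model_save_fn=model.save_pretrained,
        logger=wandb.init(project="bigbird-natural-questions"),  # hypothetical project name
        scheduler_fn=lr,
    )
    state = trainer.create_state(model, tx, num_train_steps)
    trainer.train(state, tr_dataset, val_dataset)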