From 7490a97cac20cef6858f32e5f39a61f31ad64552 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Wed, 27 Jul 2022 15:50:47 +0100 Subject: [PATCH] [Flax] Fix incomplete batches in example scripts (#17863) * [Flax] Fix incomplete batches in example scripts * fix dataloader batching * convert jnp batch idxs to np array * add missing `pad_shard_unpad` to final prediction generate step * only `pad_shard_unpad` at inference time * merge conflicts * remove incomplete batch step from eval * fix run_qa.py * add `pad_shard_unpad` to run_flax_ner.py * add `pad_shard_unpad` to run_flax_glue.py * add `pad_shard_unpad` to run_image_classification.py * make style * fix mlm flax eval batches * remove redundant imports --- .../flax/language-modeling/run_clm_flax.py | 44 ++++++----- .../flax/language-modeling/run_mlm_flax.py | 40 ++++++---- .../flax/language-modeling/run_t5_mlm_flax.py | 43 +++++++---- examples/flax/question-answering/run_qa.py | 77 +++++++------------ .../summarization/run_summarization_flax.py | 55 +++++++------ .../flax/text-classification/run_flax_glue.py | 42 +++++----- .../flax/token-classification/run_flax_ner.py | 64 +++++---------- .../flax/vision/run_image_classification.py | 12 +-- 8 files changed, 180 insertions(+), 197 deletions(-) diff --git a/examples/flax/language-modeling/run_clm_flax.py b/examples/flax/language-modeling/run_clm_flax.py index 00fc6e61f7..5fe786da7c 100755 --- a/examples/flax/language-modeling/run_clm_flax.py +++ b/examples/flax/language-modeling/run_clm_flax.py @@ -43,7 +43,7 @@ import jax.numpy as jnp import optax import transformers from flax import jax_utils, traverse_util -from flax.jax_utils import unreplicate +from flax.jax_utils import pad_shard_unpad, unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key from huggingface_hub import Repository @@ -264,20 +264,24 @@ class TrainState(train_state.TrainState): return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) -def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False): +def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): """ - Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices. - Shuffle batches if `shuffle` is `True`. + Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, + and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. """ - steps_per_epoch = len(dataset) // batch_size - if shuffle: batch_idx = jax.random.permutation(rng, len(dataset)) + batch_idx = np.asarray(batch_idx) else: - batch_idx = jnp.arange(len(dataset)) + batch_idx = np.arange(len(dataset)) - batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. - batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) + if drop_last: + steps_per_epoch = len(dataset) // batch_size + batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. + batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) + else: + steps_per_epoch = math.ceil(len(dataset) / batch_size) + batch_idx = np.array_split(batch_idx, steps_per_epoch) for idx in batch_idx: batch = dataset[idx] @@ -621,7 +625,8 @@ def main(): # Store some constant num_epochs = int(training_args.num_train_epochs) train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() - eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = per_device_eval_batch_size * jax.device_count() steps_per_epoch = len(train_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs @@ -764,13 +769,14 @@ def main(): if cur_step % training_args.eval_steps == 0 and cur_step > 0: # ======================== Evaluating ============================== eval_metrics = [] - eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size) - eval_steps = len(eval_dataset) // eval_batch_size + eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False) + eval_steps = math.ceil(len(eval_dataset) / eval_batch_size) for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): # Model forward batch = next(eval_loader) - batch = shard(batch) - metrics = p_eval_step(state.params, batch) + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, batch, min_device_batch=per_device_eval_batch_size + ) eval_metrics.append(metrics) # normalize eval metrics @@ -806,12 +812,14 @@ def main(): # Eval after training if training_args.do_eval: eval_metrics = [] - eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size) - eval_steps = len(eval_dataset) // eval_batch_size + eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False) + eval_steps = math.ceil(len(eval_dataset) / eval_batch_size) for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): # Model forward - batch = shard(next(eval_loader)) - metrics = p_eval_step(state.params, batch) + batch = next(eval_loader) + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, batch, min_device_batch=per_device_eval_batch_size + ) eval_metrics.append(metrics) # normalize eval metrics diff --git a/examples/flax/language-modeling/run_mlm_flax.py b/examples/flax/language-modeling/run_mlm_flax.py index 9657471246..4b0c8c803b 100755 --- a/examples/flax/language-modeling/run_mlm_flax.py +++ b/examples/flax/language-modeling/run_mlm_flax.py @@ -43,6 +43,7 @@ import jax import jax.numpy as jnp import optax from flax import jax_utils, traverse_util +from flax.jax_utils import pad_shard_unpad from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard from huggingface_hub import Repository @@ -326,15 +327,20 @@ class FlaxDataCollatorForLanguageModeling: return inputs, labels -def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray: +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray: + """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by + the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned.""" num_samples = len(samples_idx) - samples_to_remove = num_samples % batch_size - - if samples_to_remove != 0: - samples_idx = samples_idx[:-samples_to_remove] - sections_split = num_samples // batch_size - batch_idx = np.split(samples_idx, sections_split) - return batch_idx + if drop_last: + samples_to_remove = num_samples % batch_size + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + samples_idx = samples_idx.reshape((sections_split, batch_size)) + else: + sections_split = math.ceil(num_samples / batch_size) + samples_idx = np.array_split(samples_idx, sections_split) + return samples_idx def write_train_metric(summary_writer, train_metrics, train_time, step): @@ -632,12 +638,14 @@ def main(): config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), + use_auth_token=True if model_args.use_auth_token else None, ) # Store some constant num_epochs = int(training_args.num_train_epochs) train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() - eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = per_device_eval_batch_size * jax.device_count() num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs @@ -796,7 +804,7 @@ def main(): num_eval_samples = len(tokenized_datasets["validation"]) # Avoid using jax.numpy here in case of TPU training eval_samples_idx = np.arange(num_eval_samples) - eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) eval_metrics = [] for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): @@ -804,8 +812,9 @@ def main(): model_inputs = data_collator(samples, pad_to_multiple_of=16) # Model forward - model_inputs = shard(model_inputs.data) - metrics = p_eval_step(state.params, model_inputs) + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size + ) eval_metrics.append(metrics) # normalize eval metrics @@ -835,7 +844,7 @@ def main(): num_eval_samples = len(tokenized_datasets["validation"]) # Avoid using jax.numpy here in case of TPU training eval_samples_idx = np.arange(num_eval_samples) - eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) eval_metrics = [] for _, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): @@ -843,8 +852,9 @@ def main(): model_inputs = data_collator(samples, pad_to_multiple_of=16) # Model forward - model_inputs = shard(model_inputs.data) - metrics = p_eval_step(state.params, model_inputs) + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size + ) eval_metrics.append(metrics) # normalize eval metrics diff --git a/examples/flax/language-modeling/run_t5_mlm_flax.py b/examples/flax/language-modeling/run_t5_mlm_flax.py index ad0b43d3d6..e0943ffdfb 100755 --- a/examples/flax/language-modeling/run_t5_mlm_flax.py +++ b/examples/flax/language-modeling/run_t5_mlm_flax.py @@ -21,6 +21,7 @@ https://huggingface.co/models?filter=t5 """ import json import logging +import math import os import sys import time @@ -41,6 +42,7 @@ import jax import jax.numpy as jnp import optax from flax import jax_utils, traverse_util +from flax.jax_utils import pad_shard_unpad from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard from huggingface_hub import Repository @@ -326,6 +328,7 @@ class FlaxDataCollatorForT5MLM: decoder_start_token_id: int def __call__(self, examples: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]: + # convert list to dict and tensorize input batch = BatchEncoding( {k: np.array([examples[i][k] for i in range(len(examples))]) for k, v in examples[0].items()} @@ -394,6 +397,7 @@ class FlaxDataCollatorForT5MLM: return input_ids def random_spans_noise_mask(self, length): + """This function is copy of `random_spans_helper `__ . Noise mask consisting of random spans of noise tokens. @@ -457,15 +461,20 @@ class FlaxDataCollatorForT5MLM: return is_noise[:orig_length] -def generate_batch_splits(samples_idx: np.ndarray, batch_size: int) -> np.ndarray: +def generate_batch_splits(samples_idx: np.ndarray, batch_size: int, drop_last=True) -> np.ndarray: + """Generate batches of data for a specified batch size from sample indices. If the dataset size is not divisible by + the batch size and `drop_last` is `True`, the last incomplete batch is dropped. Else, it is returned.""" num_samples = len(samples_idx) - samples_to_remove = num_samples % batch_size - - if samples_to_remove != 0: - samples_idx = samples_idx[:-samples_to_remove] - sections_split = num_samples // batch_size - batch_idx = np.split(samples_idx, sections_split) - return batch_idx + if drop_last: + samples_to_remove = num_samples % batch_size + if samples_to_remove != 0: + samples_idx = samples_idx[:-samples_to_remove] + sections_split = num_samples // batch_size + samples_idx = samples_idx.reshape((sections_split, batch_size)) + else: + sections_split = math.ceil(num_samples / batch_size) + samples_idx = np.array_split(samples_idx, sections_split) + return samples_idx def write_train_metric(summary_writer, train_metrics, train_time, step): @@ -737,6 +746,7 @@ def main(): config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), + use_auth_token=True if model_args.use_auth_token else None, ) # Data collator @@ -754,7 +764,8 @@ def main(): # Store some constant num_epochs = int(training_args.num_train_epochs) train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() - eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = per_device_eval_batch_size * jax.device_count() num_train_steps = len(tokenized_datasets["train"]) // train_batch_size * num_epochs @@ -915,7 +926,7 @@ def main(): num_eval_samples = len(tokenized_datasets["validation"]) # Avoid using jax.numpy here in case of TPU training eval_samples_idx = np.arange(num_eval_samples) - eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) eval_metrics = [] for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): @@ -923,8 +934,9 @@ def main(): model_inputs = data_collator(samples) # Model forward - model_inputs = shard(model_inputs.data) - metrics = p_eval_step(state.params, model_inputs) + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size + ) eval_metrics.append(metrics) # get eval metrics @@ -952,7 +964,7 @@ def main(): num_eval_samples = len(tokenized_datasets["validation"]) # Avoid using jax.numpy here in case of TPU training eval_samples_idx = np.arange(num_eval_samples) - eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size) + eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size, drop_last=False) eval_metrics = [] for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=2)): @@ -960,8 +972,9 @@ def main(): model_inputs = data_collator(samples) # Model forward - model_inputs = shard(model_inputs.data) - metrics = p_eval_step(state.params, model_inputs) + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, model_inputs.data, min_device_batch=per_device_eval_batch_size + ) eval_metrics.append(metrics) # get eval metrics diff --git a/examples/flax/question-answering/run_qa.py b/examples/flax/question-answering/run_qa.py index b424756355..05315c0b6a 100644 --- a/examples/flax/question-answering/run_qa.py +++ b/examples/flax/question-answering/run_qa.py @@ -20,13 +20,13 @@ Fine-tuning the library models for question answering. import json import logging +import math import os import random import sys import time from dataclasses import asdict, dataclass, field from enum import Enum -from itertools import chain from pathlib import Path from typing import Any, Callable, Dict, Optional, Tuple @@ -40,7 +40,7 @@ import jax.numpy as jnp import optax import transformers from flax import struct, traverse_util -from flax.jax_utils import replicate, unreplicate +from flax.jax_utils import pad_shard_unpad, replicate, unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard from huggingface_hub import Repository @@ -406,11 +406,15 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int): # region eval data iterator def eval_data_collator(dataset: Dataset, batch_size: int): - """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices.""" - for i in range(len(dataset) // batch_size): - batch = dataset[i * batch_size : (i + 1) * batch_size] + """Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop.""" + batch_idx = np.arange(len(dataset)) + + steps_per_epoch = math.ceil(len(dataset) / batch_size) + batch_idx = np.array_split(batch_idx, steps_per_epoch) + + for idx in batch_idx: + batch = dataset[idx] batch = {k: np.array(v) for k, v in batch.items()} - batch = shard(batch) yield batch @@ -856,8 +860,9 @@ def main(): rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) - train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count() - eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count() + train_batch_size = int(training_args.per_device_train_batch_size) * jax.local_device_count() + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = per_device_eval_batch_size * jax.local_device_count() # endregion # region Load model @@ -975,32 +980,17 @@ def main(): # evaluate for batch in tqdm( eval_data_collator(eval_dataset, eval_batch_size), - total=len(eval_dataset) // eval_batch_size, + total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2, ): _ = batch.pop("example_id") _ = batch.pop("offset_mapping") - predictions = p_eval_step(state, batch) - start_logits = np.array([pred for pred in chain(*predictions[0])]) - end_logits = np.array([pred for pred in chain(*predictions[1])]) - all_start_logits.append(start_logits) - all_end_logits.append(end_logits) - - # evaluate also on leftover examples (not divisible by batch_size) - num_leftover_samples = len(eval_dataset) % eval_batch_size - - # make sure leftover batch is evaluated on one device - if num_leftover_samples > 0 and jax.process_index() == 0: - # take leftover samples - batch = eval_dataset[-num_leftover_samples:] - batch = {k: np.array(v) for k, v in batch.items()} - _ = batch.pop("example_id") - _ = batch.pop("offset_mapping") - - predictions = eval_step(unreplicate(state), batch) - start_logits = np.array([pred for pred in predictions[0]]) - end_logits = np.array([pred for pred in predictions[1]]) + predictions = pad_shard_unpad(p_eval_step)( + state, batch, min_device_batch=per_device_eval_batch_size + ) + start_logits = np.array(predictions[0]) + end_logits = np.array(predictions[1]) all_start_logits.append(start_logits) all_end_logits.append(end_logits) @@ -1039,30 +1029,15 @@ def main(): all_start_logits = [] all_end_logits = [] - eva_loader = eval_data_collator(eval_dataset, eval_batch_size) - for batch in tqdm(eva_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2): + eval_loader = eval_data_collator(eval_dataset, eval_batch_size) + for batch in tqdm( + eval_loader, total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2 + ): _ = batch.pop("example_id") _ = batch.pop("offset_mapping") - predictions = p_eval_step(state, batch) - start_logits = np.array([pred for pred in chain(*predictions[0])]) - end_logits = np.array([pred for pred in chain(*predictions[1])]) - all_start_logits.append(start_logits) - all_end_logits.append(end_logits) - - # evaluate also on leftover examples (not divisible by batch_size) - num_leftover_samples = len(eval_dataset) % eval_batch_size - - # make sure leftover batch is evaluated on one device - if num_leftover_samples > 0 and jax.process_index() == 0: - # take leftover samples - batch = eval_dataset[-num_leftover_samples:] - batch = {k: np.array(v) for k, v in batch.items()} - _ = batch.pop("example_id") - _ = batch.pop("offset_mapping") - - predictions = eval_step(unreplicate(state), batch) - start_logits = np.array([pred for pred in predictions[0]]) - end_logits = np.array([pred for pred in predictions[1]]) + predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size) + start_logits = np.array(predictions[0]) + end_logits = np.array(predictions[1]) all_start_logits.append(start_logits) all_end_logits.append(end_logits) diff --git a/examples/flax/summarization/run_summarization_flax.py b/examples/flax/summarization/run_summarization_flax.py index a1b5fc37e2..bd17141a44 100644 --- a/examples/flax/summarization/run_summarization_flax.py +++ b/examples/flax/summarization/run_summarization_flax.py @@ -20,6 +20,7 @@ Fine-tuning the library models for summarization. import json import logging +import math import os import sys import time @@ -41,7 +42,7 @@ import optax import transformers from filelock import FileLock from flax import jax_utils, traverse_util -from flax.jax_utils import unreplicate +from flax.jax_utils import pad_shard_unpad, unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key from huggingface_hub import Repository @@ -335,26 +336,28 @@ class TrainState(train_state.TrainState): return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng)) -def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False): +def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False, drop_last=True): """ - Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices. - Shuffle batches if `shuffle` is `True`. + Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete, + and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`. """ - steps_per_epoch = len(dataset) // batch_size - if shuffle: batch_idx = jax.random.permutation(rng, len(dataset)) + batch_idx = np.asarray(batch_idx) else: - batch_idx = jnp.arange(len(dataset)) + batch_idx = np.arange(len(dataset)) - batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. - batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) + if drop_last: + steps_per_epoch = len(dataset) // batch_size + batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch. + batch_idx = batch_idx.reshape((steps_per_epoch, batch_size)) + else: + steps_per_epoch = math.ceil(len(dataset) / batch_size) + batch_idx = np.array_split(batch_idx, steps_per_epoch) for idx in batch_idx: batch = dataset[idx] - batch = {k: jnp.array(v) for k, v in batch.items()} - - batch = shard(batch) + batch = {k: np.array(v) for k, v in batch.items()} yield batch @@ -706,7 +709,8 @@ def main(): # Store some constant num_epochs = int(training_args.num_train_epochs) train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() - eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = per_device_eval_batch_size * jax.device_count() steps_per_epoch = len(train_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs @@ -850,6 +854,7 @@ def main(): # train for _ in tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False): batch = next(train_loader) + batch = shard(batch) state, train_metric = p_train_step(state, batch) train_metrics.append(train_metric) @@ -867,21 +872,23 @@ def main(): eval_preds = [] eval_labels = [] - eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size) - eval_steps = len(eval_dataset) // eval_batch_size + eval_loader = data_loader(input_rng, eval_dataset, eval_batch_size, drop_last=False) + eval_steps = math.ceil(len(eval_dataset) / eval_batch_size) for _ in tqdm(range(eval_steps), desc="Evaluating...", position=2, leave=False): # Model forward batch = next(eval_loader) labels = batch["labels"] - metrics = p_eval_step(state.params, batch) + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, batch, min_device_batch=per_device_eval_batch_size + ) eval_metrics.append(metrics) # generation if data_args.predict_with_generate: - generated_ids = p_generate_step(state.params, batch) + generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) - eval_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1]))) + eval_labels.extend(labels) # normalize eval metrics eval_metrics = get_metrics(eval_metrics) @@ -920,21 +927,23 @@ def main(): pred_generations = [] pred_labels = [] - pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size) - pred_steps = len(predict_dataset) // eval_batch_size + pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size, drop_last=False) + pred_steps = math.ceil(len(predict_dataset) / eval_batch_size) for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False): # Model forward batch = next(pred_loader) labels = batch["labels"] - metrics = p_eval_step(state.params, batch) + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, batch, min_device_batch=per_device_eval_batch_size + ) pred_metrics.append(metrics) # generation if data_args.predict_with_generate: - generated_ids = p_generate_step(state.params, batch) + generated_ids = pad_shard_unpad(p_generate_step)(state.params, batch) pred_generations.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"]))) - pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1]))) + pred_labels.extend(labels) # normalize prediction metrics pred_metrics = get_metrics(pred_metrics) diff --git a/examples/flax/text-classification/run_flax_glue.py b/examples/flax/text-classification/run_flax_glue.py index d777123185..233a06c3ae 100755 --- a/examples/flax/text-classification/run_flax_glue.py +++ b/examples/flax/text-classification/run_flax_glue.py @@ -16,12 +16,12 @@ """ Finetuning a 🤗 Flax Transformers model for sequence classification on GLUE.""" import json import logging +import math import os import random import sys import time from dataclasses import dataclass, field -from itertools import chain from pathlib import Path from typing import Any, Callable, Dict, Optional, Tuple @@ -35,7 +35,7 @@ import jax.numpy as jnp import optax import transformers from flax import struct, traverse_util -from flax.jax_utils import replicate, unreplicate +from flax.jax_utils import pad_shard_unpad, replicate, unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard from huggingface_hub import Repository @@ -300,11 +300,15 @@ def glue_train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int): def glue_eval_data_collator(dataset: Dataset, batch_size: int): - """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices.""" - for i in range(len(dataset) // batch_size): - batch = dataset[i * batch_size : (i + 1) * batch_size] + """Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop.""" + batch_idx = np.arange(len(dataset)) + + steps_per_epoch = math.ceil(len(dataset) / batch_size) + batch_idx = np.array_split(batch_idx, steps_per_epoch) + + for idx in batch_idx: + batch = dataset[idx] batch = {k: np.array(v) for k, v in batch.items()} - batch = shard(batch) yield batch @@ -521,8 +525,9 @@ def main(): rng = jax.random.PRNGKey(training_args.seed) dropout_rngs = jax.random.split(rng, jax.local_device_count()) - train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count() - eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count() + train_batch_size = int(training_args.per_device_train_batch_size) * jax.local_device_count() + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = per_device_eval_batch_size * jax.device_count() learning_rate_fn = create_learning_rate_fn( len(train_dataset), @@ -621,26 +626,15 @@ def main(): eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size) for batch in tqdm( eval_loader, - total=len(eval_dataset) // eval_batch_size, + total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2, ): labels = batch.pop("labels") - predictions = p_eval_step(state, batch) - metric.add_batch(predictions=chain(*predictions), references=chain(*labels)) - - # evaluate also on leftover examples (not divisible by batch_size) - num_leftover_samples = len(eval_dataset) % eval_batch_size - - # make sure leftover batch is evaluated on one device - if num_leftover_samples > 0 and jax.process_index() == 0: - # take leftover samples - batch = eval_dataset[-num_leftover_samples:] - batch = {k: np.array(v) for k, v in batch.items()} - - labels = batch.pop("labels") - predictions = eval_step(unreplicate(state), batch) - metric.add_batch(predictions=predictions, references=labels) + predictions = pad_shard_unpad(p_eval_step)( + state, batch, min_device_batch=per_device_eval_batch_size + ) + metric.add_batch(predictions=np.array(predictions), references=labels) eval_metric = metric.compute() diff --git a/examples/flax/token-classification/run_flax_ner.py b/examples/flax/token-classification/run_flax_ner.py index 682fc03b8b..062b10d2e3 100644 --- a/examples/flax/token-classification/run_flax_ner.py +++ b/examples/flax/token-classification/run_flax_ner.py @@ -16,6 +16,7 @@ """ Fine-tuning a 🤗 Flax Transformers model on token classification tasks (NER, POS, CHUNKS)""" import json import logging +import math import os import random import sys @@ -36,7 +37,7 @@ import jax.numpy as jnp import optax import transformers from flax import struct, traverse_util -from flax.jax_utils import replicate, unreplicate +from flax.jax_utils import pad_shard_unpad, replicate, unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard from huggingface_hub import Repository @@ -351,11 +352,15 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int): def eval_data_collator(dataset: Dataset, batch_size: int): - """Returns batches of size `batch_size` from `eval dataset`, sharded over all local devices.""" - for i in range(len(dataset) // batch_size): - batch = dataset[i * batch_size : (i + 1) * batch_size] + """Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop.""" + batch_idx = np.arange(len(dataset)) + + steps_per_epoch = math.ceil(len(dataset) / batch_size) + batch_idx = np.array_split(batch_idx, steps_per_epoch) + + for idx in batch_idx: + batch = dataset[idx] batch = {k: np.array(v) for k, v in batch.items()} - batch = shard(batch) yield batch @@ -600,6 +605,7 @@ def main(): dropout_rngs = jax.random.split(rng, jax.local_device_count()) train_batch_size = training_args.per_device_train_batch_size * jax.local_device_count() + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) eval_batch_size = training_args.per_device_eval_batch_size * jax.local_device_count() learning_rate_fn = create_learning_rate_fn( @@ -728,34 +734,16 @@ def main(): # evaluate for batch in tqdm( eval_data_collator(eval_dataset, eval_batch_size), - total=len(eval_dataset) // eval_batch_size, + total=math.ceil(len(eval_dataset) / eval_batch_size), desc="Evaluating ...", position=2, ): labels = batch.pop("labels") - predictions = p_eval_step(state, batch) - predictions = np.array([pred for pred in chain(*predictions)]) - labels = np.array([label for label in chain(*labels)]) - labels[np.array(chain(*batch["attention_mask"])) == 0] = -100 - preds, refs = get_labels(predictions, labels) - metric.add_batch( - predictions=preds, - references=refs, + predictions = pad_shard_unpad(p_eval_step)( + state, batch, min_device_batch=per_device_eval_batch_size ) - - # evaluate also on leftover examples (not divisible by batch_size) - num_leftover_samples = len(eval_dataset) % eval_batch_size - - # make sure leftover batch is evaluated on one device - if num_leftover_samples > 0 and jax.process_index() == 0: - # take leftover samples - batch = eval_dataset[-num_leftover_samples:] - batch = {k: np.array(v) for k, v in batch.items()} - - labels = batch.pop("labels") - predictions = eval_step(unreplicate(state), batch) - labels = np.array(labels) - labels[np.array(batch["attention_mask"]) == 0] = -100 + predictions = np.array(predictions) + labels[np.array(chain(*batch["attention_mask"])) == 0] = -100 preds, refs = get_labels(predictions, labels) metric.add_batch( predictions=preds, @@ -791,28 +779,12 @@ def main(): eval_loader = eval_data_collator(eval_dataset, eval_batch_size) for batch in tqdm(eval_loader, total=len(eval_dataset) // eval_batch_size, desc="Evaluating ...", position=2): labels = batch.pop("labels") - predictions = p_eval_step(state, batch) - predictions = np.array([pred for pred in chain(*predictions)]) - labels = np.array([label for label in chain(*labels)]) + predictions = pad_shard_unpad(p_eval_step)(state, batch, min_device_batch=per_device_eval_batch_size) + predictions = np.array(predictions) labels[np.array(chain(*batch["attention_mask"])) == 0] = -100 preds, refs = get_labels(predictions, labels) metric.add_batch(predictions=preds, references=refs) - # evaluate also on leftover examples (not divisible by batch_size) - num_leftover_samples = len(eval_dataset) % eval_batch_size - - # make sure leftover batch is evaluated on one device - if num_leftover_samples > 0 and jax.process_index() == 0: - # take leftover samples - batch = eval_dataset[-num_leftover_samples:] - batch = {k: np.array(v) for k, v in batch.items()} - - labels = np.array(batch.pop("labels")) - predictions = eval_step(unreplicate(state), batch) - labels[np.array(batch["attention_mask"]) == 0] = -100 - preds, refs = get_labels(predictions, labels) - metric.add_batch(predictions=preds, references=refs) - eval_metrics = compute_metrics() if jax.process_index() == 0: diff --git a/examples/flax/vision/run_image_classification.py b/examples/flax/vision/run_image_classification.py index 2883f7a38a..305dd3ac20 100644 --- a/examples/flax/vision/run_image_classification.py +++ b/examples/flax/vision/run_image_classification.py @@ -40,7 +40,7 @@ import jax.numpy as jnp import optax import transformers from flax import jax_utils -from flax.jax_utils import unreplicate +from flax.jax_utils import pad_shard_unpad, unreplicate from flax.training import train_state from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key from huggingface_hub import Repository @@ -368,7 +368,8 @@ def main(): # Store some constant num_epochs = int(training_args.num_train_epochs) train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count() - eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count() + per_device_eval_batch_size = int(training_args.per_device_eval_batch_size) + eval_batch_size = per_device_eval_batch_size * jax.device_count() steps_per_epoch = len(train_dataset) // train_batch_size total_train_steps = steps_per_epoch * num_epochs @@ -398,7 +399,7 @@ def main(): shuffle=False, num_workers=data_args.preprocessing_num_workers, persistent_workers=True, - drop_last=True, + drop_last=False, collate_fn=collate_fn, ) @@ -532,8 +533,9 @@ def main(): eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False) for batch in eval_loader: # Model forward - batch = shard(batch) - metrics = p_eval_step(state.params, batch) + metrics = pad_shard_unpad(p_eval_step, static_return=True)( + state.params, batch, min_device_batch=per_device_eval_batch_size + ) eval_metrics.append(metrics) eval_step_progress_bar.update(1)