JAX/Flax Examples

This folder contains actively maintained examples of 🤗 Transformers using the JAX/Flax backend. Porting models and examples to JAX/Flax is an ongoing effort, and more will be added in the coming months. In particular, these examples are all designed to run fast on Cloud TPUs, and we include step-by-step guides to getting started with Cloud TPU.

NOTE: Currently, there is no "Trainer" abstraction for JAX/Flax -- all examples contain an explicit training loop.
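
For orientation, here is a minimal sketch of that explicit-loop pattern, using a hypothetical toy model rather than one of the example scripts; the real examples wrap a Transformers Flax model and a 🤗 Datasets dataloader in the same structure:

```python
# Minimal sketch of the explicit training-loop pattern (toy model and data).
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training import train_state


class ToyRegressor(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)


def create_state(rng, learning_rate=1e-3):
    model = ToyRegressor()
    params = model.init(rng, jnp.ones((1, 8)))["params"]
    tx = optax.adamw(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)


@jax.jit
def train_step(state, batch):
    def loss_fn(params):
        preds = state.apply_fn({"params": params}, batch["x"])
        return jnp.mean((preds - batch["y"]) ** 2)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return state.apply_gradients(grads=grads), loss


state = create_state(jax.random.PRNGKey(0))
batch = {"x": jnp.ones((4, 8)), "y": jnp.zeros((4, 1))}
for step in range(10):  # the explicit loop the note refers to
    state, loss = train_step(state, batch)
```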

The following table lists all of our examples of how to use 🤗 Transformers with the JAX/Flax backend:

  • with information about the model and dataset used,
  • whether or not they leverage the 🤗 Datasets library,
  • links to Colab notebooks to walk through the scripts and run them easily.
| Task | Example model | Example dataset | 🤗 Datasets | Colab |
|---|---|---|---|---|
| causal-language-modeling | GPT2 | OSCAR | ✅ | Open In Colab |
| masked-language-modeling | RoBERTa | OSCAR | ✅ | Open In Colab |
| text-classification | BERT | GLUE | ✅ | Open In Colab |

Intro: JAX and Flax

JAX is a numerical computation library that exposes a NumPy-like API with tracing capabilities. With JAX's jit, you can trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU. JAX supports additional transformations such as grad (for arbitrary gradients), pmap (for parallelizing computation on multiple devices), remat (for gradient checkpointing), vmap (automatic efficient vectorization), and pjit (for automatically sharded model parallelism). All JAX transformations compose arbitrarily with each other -- e.g., efficiently computing per-example gradients is simply vmap(grad(f)).
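
For example, a minimal sketch of this composition, using a toy least-squares loss rather than code from the examples:

```python
# Per-example gradients via vmap(grad(f)), compiled with jit.
import jax
import jax.numpy as jnp


def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


# grad(loss) gives the gradient for a single example; vmap maps it over a batch,
# and jit compiles the whole composition into fused accelerator code.
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.zeros((8,))
x = jnp.ones((4, 8))   # batch of 4 examples
y = jnp.zeros((4,))
grads = per_example_grads(w, x, y)  # shape (4, 8): one gradient per example
```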

Flax builds on top of JAX with an ergonomic module abstraction using Python dataclasses that leads to concise and explicit code. Flax's "lifted" JAX transformations (e.g. vmap, remat) allow you to nest JAX transformations and modules in any way you wish. Flax is the most widely used JAX library, with 129 dependent projects as of May 2021. It is also the library underlying all of the official Cloud TPU JAX examples.
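
A minimal sketch of what such a module looks like (a hypothetical two-layer MLP, not one of the example models): hyper-parameters are declared as dataclass fields, and parameters are created lazily inside the call.

```python
# Flax module: dataclass fields hold hyper-parameters, @nn.compact builds layers inline.
import jax
import jax.numpy as jnp
from flax import linen as nn


class MLP(nn.Module):
    hidden_dim: int = 32  # dataclass field == hyper-parameter
    out_dim: int = 2

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim)(x)
        x = nn.relu(x)
        return nn.Dense(self.out_dim)(x)


model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 16)))["params"]
logits = model.apply({"params": params}, jnp.ones((4, 16)))
```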

Running on Cloud TPU

All of our JAX/Flax models are designed to run efficiently on Google Cloud TPUs. Here is a guide for running JAX on Google Cloud TPU.
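
Once a TPU VM is set up as described in that guide (with the TPU-enabled jax build installed), a quick sanity check is to ask JAX which devices it sees:

```python
# Sanity check on a Cloud TPU VM: JAX should list the TPU cores.
import jax

print(jax.devices())        # on a TPU VM, a list of TPU device entries
print(jax.device_count())   # number of cores visible to pmap
```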

Each example README contains more details on the specific model and training procedure.

Supported models

Porting models from PyTorch to JAX/Flax is an ongoing effort. Feel free to reach out if you are interested in contributing a model in JAX/Flax -- we'll be adding a guide for porting models from PyTorch in the coming weeks.

For a complete overview of models that are supported in JAX/Flax, please have a look at this table.

Over 3000 pretrained checkpoints are supported in JAX/Flax as of May 2021. Click here to see the full list on the 🤗 hub.