<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
# XLA Integration for TensorFlow Models

[[open-in-colab]]

Accelerated Linear Algebra, dubbed XLA, is a compiler that accelerates the runtime of TensorFlow models. From the [official documentation](https://www.tensorflow.org/xla):

> XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that can accelerate TensorFlow models with potentially no source code changes.

Using XLA in TensorFlow is simple – it comes packaged inside the `tensorflow` library, and it can be triggered with the `jit_compile` argument in any graph-creating function such as [`tf.function`](https://www.tensorflow.org/guide/intro_to_graphs). When using Keras methods like `fit()` and `predict()`, you can enable XLA simply by passing the `jit_compile` argument to `model.compile()`. However, XLA is not limited to these methods – it can also be used to accelerate any arbitrary `tf.function`.
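
For example, here is a minimal sketch of enabling XLA for Keras training and inference via `model.compile()`; the toy model and random data below are purely illustrative:

```py
import tensorflow as tf

# Any Keras model works here; this toy regressor is just for illustration.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])

# Passing jit_compile=True makes Keras run fit()/predict() through XLA.
model.compile(optimizer="adam", loss="mse", jit_compile=True)

x = tf.random.normal((64, 10))
y = tf.random.normal((64, 1))
model.fit(x, y, epochs=1, verbose=0)
_ = model.predict(x, verbose=0)
```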
Several TensorFlow methods in 🤗 Transformers have been rewritten to be XLA-compatible, including text generation for models such as [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2), [T5](https://huggingface.co/docs/transformers/model_doc/t5) and [OPT](https://huggingface.co/docs/transformers/model_doc/opt), as well as speech processing for models such as [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper).

While the exact amount of speed-up is very much model-dependent, for TensorFlow text generation models inside 🤗 Transformers, we noticed a speed-up of ~100x. This document explains how you can use XLA for these models to get the maximum amount of performance. We'll also provide links to additional resources if you're interested in learning more about the benchmarks and our design philosophy behind the XLA integration.

## Running TF functions with XLA
Let us consider the following model in TensorFlow:
```py
import tensorflow as tf

model = tf.keras.Sequential(
    [tf.keras.layers.Dense(10, input_shape=(10,), activation="relu"), tf.keras.layers.Dense(5, activation="softmax")]
)
```
The above model accepts inputs with a dimension of `(10,)`. We can use the model to run a forward pass like so:
```py
# Generate random inputs for the model.
batch_size = 16
input_vector_dim = 10
random_inputs = tf.random.normal((batch_size, input_vector_dim))

# Run a forward pass.
_ = model(random_inputs)
```
To run the forward pass with an XLA-compiled function, we'd do:
```py
xla_fn = tf.function(model, jit_compile=True)
_ = xla_fn(random_inputs)
```
The default `call()` function of the `model` is used to compile the XLA graph. But if there is any other model function you want to compile with XLA, that is also possible:
```py
my_xla_fn = tf.function(model.my_xla_fn, jit_compile=True)
```
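
For instance, here is a minimal sketch of compiling a custom method of a subclassed model; the `ToyModel` class and its `top_prediction` method are hypothetical and only serve as an illustration:

```py
import tensorflow as tf


class ToyModel(tf.keras.Model):
    """Hypothetical model with an extra inference method besides call()."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(5, activation="softmax")

    def call(self, inputs):
        return self.dense(inputs)

    def top_prediction(self, inputs):
        # Custom method we want to compile with XLA.
        return tf.argmax(self.dense(inputs), axis=-1)


toy_model = ToyModel()
xla_top_prediction = tf.function(toy_model.top_prediction, jit_compile=True)
_ = xla_top_prediction(tf.random.normal((16, 10)))
```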
## Running a TF text generation model with XLA from 🤗 Transformers
To enable XLA-accelerated generation within 🤗 Transformers, you need to have a recent version of `transformers` installed. You can install it by running:
```bash
pip install transformers --upgrade
```
And then you can run the following code:
```py
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

# Will error if the minimal version of Transformers is not installed.
from transformers.utils import check_min_version

check_min_version("4.21.0")


tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")
input_string = ["TensorFlow is"]

# One line to create an XLA generation function
xla_generate = tf.function(model.generate, jit_compile=True)

tokenized_input = tokenizer(input_string, return_tensors="tf")
generated_tokens = xla_generate(**tokenized_input, num_beams=2)

decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
# Generated -- TensorFlow is an open-source, open-source, distributed-source application # framework for the
```
As you can see, enabling XLA on `generate()` takes just a single line of code; the rest of the code remains unchanged. However, there are a couple of XLA-specific gotchas in the above snippet that you need to be aware of to realize the speed-ups XLA can bring. We discuss these in the following section.
## Gotchas to be aware of
When you are executing an XLA-enabled function (like `xla_generate()` above) for the first time, it will internally try to infer the computation graph, which is time-consuming. This process is known as [“tracing”](https://www.tensorflow.org/guide/intro_to_graphs#when_is_a_function_tracing).

This is why the first generation call might feel slow. Successive calls of `xla_generate()` (or any other XLA-enabled function) won't have to infer the computation graph, provided the inputs to the function have the same shape as the one with which the computation graph was initially built. While this is not a problem for modalities with fixed input shapes (e.g., images), you must pay attention if you are working with modalities that have variable input shapes (e.g., text).
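
To see tracing and re-tracing in action outside of text generation, here is a minimal sketch (unrelated to 🤗 Transformers) using a plain `tf.function`:

```py
import tensorflow as tf


@tf.function(jit_compile=True)
def square_sum(x):
    # This Python-side print only runs while the function is being traced.
    print("Tracing with shape:", x.shape)
    return tf.reduce_sum(x**2)


_ = square_sum(tf.ones((4, 10)))   # first call: traced and compiled (slow), prints the message
_ = square_sum(tf.zeros((4, 10)))  # same shape: reuses the compiled graph, no print
_ = square_sum(tf.ones((8, 10)))   # new shape: triggers a re-trace (slow again)
```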
To ensure `xla_generate()` always operates with the same input shapes, you can specify the `padding` arguments when calling the tokenizer.
```py
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")
input_string = ["TensorFlow is"]

xla_generate = tf.function(model.generate, jit_compile=True)

# Here, we call the tokenizer with padding options.
tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")

generated_tokens = xla_generate(**tokenized_input, num_beams=2)
decoded_text = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
print(f"Generated -- {decoded_text}")
```
This way, you can ensure that `xla_generate()` always receives inputs with the shape it was traced with, which speeds up generation. You can verify this with the code below:
```py
import time
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2", padding_side="left", pad_token="</s>")
model = TFAutoModelForCausalLM.from_pretrained("openai-community/gpt2")

xla_generate = tf.function(model.generate, jit_compile=True)

for input_string in ["TensorFlow is", "TensorFlow is a", "TFLite is a"]:
    tokenized_input = tokenizer(input_string, pad_to_multiple_of=8, padding=True, return_tensors="tf")
    start = time.time_ns()
    generated_tokens = xla_generate(**tokenized_input, num_beams=2)
    end = time.time_ns()
    print(f"Execution time -- {(end - start) / 1e6:.1f} ms\n")
```
On a Tesla T4 GPU, you can expect output like the following:
```bash
Execution time -- 30819.6 ms

Execution time -- 79.0 ms

Execution time -- 78.9 ms
```
The first call to `xla_generate()` is time-consuming because of tracing, but the successive calls are orders of magnitude faster. Keep in mind that any change in the generation options at any point will trigger re-tracing, slowing generation down again.
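
For example, here is a sketch continuing from the timing snippet above (re-using its `time`, `tokenizer`, and `xla_generate`): changing a generation option such as the number of beams changes the traced graph, so the first call with the new value pays the tracing cost again.

```py
tokenized_input = tokenizer(["TensorFlow is"], pad_to_multiple_of=8, padding=True, return_tensors="tf")

# Changing num_beams (previously 2) changes the traced graph, so this call re-traces.
start = time.time_ns()
generated_tokens = xla_generate(**tokenized_input, num_beams=4)
end = time.time_ns()
print(f"Execution time with num_beams=4 -- {(end - start) / 1e6:.1f} ms")  # slow again due to re-tracing
```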
This document doesn't cover all the text generation options 🤗 Transformers provides. We encourage you to read the documentation for advanced use cases.
## Additional Resources
Here, we leave you with some additional resources if you want to delve deeper into XLA in 🤗 Transformers and in general.
* [This Colab Notebook](https://colab.research.google.com/github/huggingface/blog/blob/main/notebooks/91_tf_xla_generate.ipynb) provides an interactive demonstration if you want to fiddle with the XLA-compatible encoder-decoder (like [T5](https://huggingface.co/docs/transformers/model_doc/t5)) and decoder-only (like [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)) text generation models.
* [This blog post](https://huggingface.co/blog/tf-xla-generate) provides an overview of the comparison benchmarks for XLA-compatible models along with a friendly introduction to XLA in TensorFlow.
* [This blog post](https://blog.tensorflow.org/2022/11/how-hugging-face-improved-text-generation-performance-with-xla.html) discusses our design philosophy behind adding XLA support to the TensorFlow models in 🤗 Transformers.
* Recommended posts for learning more about XLA and TensorFlow graphs in general:
    * [XLA: Optimizing Compiler for Machine Learning](https://www.tensorflow.org/xla)
    * [Introduction to graphs and tf.function](https://www.tensorflow.org/guide/intro_to_graphs)
    * [Better performance with tf.function](https://www.tensorflow.org/guide/function)