fix torchscript docs (#6740)

Patrick von Platen 2020-08-26 10:51:56 +02:00 committed by GitHub
parent 64c7c2bc15
commit fa8ee8e855
1 changed file with 9 additions and 10 deletions


@@ -130,13 +130,12 @@ Pytorch's two modules `JIT and TRACE <https://pytorch.org/docs/stable/jit.html>`
 their model to be re-used in other programs, such as efficiency-oriented C++ programs.
 We have provided an interface that allows the export of 🤗 Transformers models to TorchScript so that they can
-be reused in a different environment than a Pytorch-based python program. Here we explain how to use our models so that
-they can be exported, and what to be mindful of when using these models with TorchScript.
+be reused in a different environment than a Pytorch-based python program. Here we explain how to export and use our models using TorchScript.
-Exporting a model needs two things:
+Exporting a model requires two things:
-* dummy inputs to execute a model forward pass.
-* the model needs to be instantiated with the ``torchscript`` flag.
+* a forward pass with dummy inputs.
+* model instantiation with the ``torchscript`` flag.
 These necessities imply several things developers should be careful about. These are detailed below.
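To make the two requirements concrete, here is a minimal sketch (not part of the diff; ``bert-base-uncased`` and the input text are placeholder choices) of instantiating a model with the ``torchscript`` flag and building dummy inputs for the tracing forward pass:

.. code-block:: python

    import torch
    from transformers import BertModel, BertTokenizer

    # Instantiate with the ``torchscript`` flag so the model can be traced
    # (this is what the untied/cloned weights discussed below are needed for).
    model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
    model.eval()

    # Build dummy inputs; tracing will execute one forward pass on them.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    input_ids = torch.tensor([tokenizer.encode("Hello, TorchScript!", add_special_tokens=True)])
    attention_mask = torch.ones_like(input_ids)
    dummy_input = (input_ids, attention_mask)

    # The trace records the forward pass executed on the dummy inputs.
    traced_model = torch.jit.trace(model, dummy_input)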
@@ -147,8 +146,8 @@ Implications
 TorchScript flag and tied weights
 ------------------------------------------------
 This flag is necessary because most of the language models in this repository have tied weights between their
-``Embedding`` layer and their ``Decoding`` layer. TorchScript does not allow the export of models that have tied weights,
-it is therefore necessary to untie the weights beforehand.
+``Embedding`` layer and their ``Decoding`` layer. TorchScript does not allow the export of models that have tied weights, therefore
+it is necessary to untie and clone the weights beforehand.
 This implies that models instantiated with the ``torchscript`` flag have their ``Embedding`` layer and ``Decoding`` layer
 separate, which means that they should not be trained down the line. Training would de-synchronize the two layers,
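As an aside on what "untie and clone" means in practice, a hedged sketch (the attribute paths below are those of ``BertForMaskedLM`` and may differ for other architectures): without the flag the input embedding and the output decoder share the same storage, with the flag the decoder weight is a clone.

.. code-block:: python

    from transformers import BertForMaskedLM

    tied = BertForMaskedLM.from_pretrained("bert-base-uncased")
    untied = BertForMaskedLM.from_pretrained("bert-base-uncased", torchscript=True)

    # Tied weights: both layers point at the same underlying storage.
    print(tied.bert.embeddings.word_embeddings.weight.data_ptr()
          == tied.cls.predictions.decoder.weight.data_ptr())    # True

    # With ``torchscript=True`` the decoder weight is a clone, so training
    # one layer would no longer update the other.
    print(untied.bert.embeddings.word_embeddings.weight.data_ptr()
          == untied.cls.predictions.decoder.weight.data_ptr())  # False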
@@ -181,7 +180,7 @@ when exporting varying sequence-length models.
 Using TorchScript in Python
 -------------------------------------------------
-Below are examples of using the Python to save, load models as well as how to use the trace for inference.
+Below is an example, showing how to save, load models as well as how to use the trace for inference.
 Saving a model
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
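For context on the loading hunk that follows, a minimal saving sketch continuing the names from the first sketch above (the full example lives in the documentation section this hunk belongs to); ``traced_bert.pt`` matches the filename the corrected loading line uses:

.. code-block:: python

    # Serialize the traced module to disk so it can be reloaded later
    # without the original Python model code.
    torch.jit.save(traced_model, "traced_bert.pt")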
@@ -237,10 +236,10 @@ We are re-using the previously initialised ``dummy_input``.
 .. code-block:: python
-    loaded_model = torch.jit.load("traced_model.pt")
+    loaded_model = torch.jit.load("traced_bert.pt")
     loaded_model.eval()
-    all_encoder_layers, pooled_output = loaded_model(dummy_input)
+    all_encoder_layers, pooled_output = loaded_model(*dummy_input)
 Using a traced model for inference
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
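The body of this last section is not shown in the hunk; as a sketch of what running the traced module looks like (reusing the names from the sketches above, and assuming ``dummy_input`` is a tuple of tensors, which is why the corrected line unpacks it with ``*``):

.. code-block:: python

    # A traced module is called like the original model; unpacking the tuple
    # passes the input ids and the attention mask as separate arguments.
    all_encoder_layers, pooled_output = traced_model(*dummy_input)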