Fix mixed precision in TF models (#9163)

* Fix Gelu precision

* Fix gelu_fast

* Naming

* Fix usage and apply style

* add TF gelu approximate version

* add TF gelu approximate version

* add TF gelu approximate version

* Apply style

* Fix albert

* Remove the usage of the Activation layer
Julien Plu 2021-01-21 13:00:11 +01:00 committed by GitHub
parent 248fa1ae72
commit 3f290e6c84
8 changed files with 28 additions and 15 deletions


@@ -15,9 +15,10 @@
 import math
 import tensorflow as tf
+from packaging import version
-def gelu(x):
+def _gelu(x):
     """
     Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
     initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
@@ -25,12 +26,12 @@ def gelu(x):
     https://arxiv.org/abs/1606.08415
     """
     x = tf.convert_to_tensor(x)
-    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
+    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
     return x * cdf
-def gelu_new(x):
+def _gelu_new(x):
     """
     Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.08415
@@ -56,21 +57,33 @@ def mish(x):
 def gelu_fast(x):
     x = tf.convert_to_tensor(x)
-    coeff1 = tf.cast(7978845608, x.dtype)
+    coeff1 = tf.cast(0.7978845608, x.dtype)
     coeff2 = tf.cast(0.044715, x.dtype)
     return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))
+if version.parse(tf.version.VERSION) >= version.parse("2.4"):
+
+    def approximate_gelu_wrap(x):
+        return tf.keras.activations.gelu(x, approximate=True)
+
+    gelu = tf.keras.activations.gelu
+    gelu_new = approximate_gelu_wrap
+else:
+    gelu = _gelu
+    gelu_new = _gelu_new
 ACT2FN = {
-    "gelu": tf.keras.layers.Activation(gelu),
+    "gelu": gelu,
     "relu": tf.keras.activations.relu,
     "swish": tf.keras.activations.swish,
     "silu": tf.keras.activations.swish,
-    "gelu_new": tf.keras.layers.Activation(gelu_new),
-    "mish": tf.keras.layers.Activation(mish),
+    "gelu_new": gelu_new,
+    "mish": mish,
     "tanh": tf.keras.activations.tanh,
-    "gelu_fast": tf.keras.layers.Activation(gelu_fast),
+    "gelu_fast": gelu_fast,
 }
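
For context, a minimal standalone sketch (not part of the commit) of why the casts above matter under mixed precision: before the change, x / tf.math.sqrt(2.0) combined a float16 activation tensor with a float32 constant, and TensorFlow binary ops require matching dtypes, so the call failed as soon as a model ran under a half-precision policy. The gelu_fast change is unrelated to dtypes: it restores the missing decimal point, since the tanh approximation uses sqrt(2/pi), roughly 0.7978845608.

import tensorflow as tf


def _gelu(x):
    # Same shape as the patched function above: every constant is cast to the
    # dtype of the incoming tensor, so float16/bfloat16 activations work unchanged.
    x = tf.convert_to_tensor(x)
    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
    return x * cdf


# Under a mixed-precision policy, layer outputs are float16; with the old
# float32 constant this division raised a dtype-mismatch error.
x = tf.constant([-1.0, 0.0, 1.0], dtype=tf.float16)
print(_gelu(x))  # roughly [-0.1587, 0.0, 0.8413], dtype=float16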


@@ -542,7 +542,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.activation(inputs=hidden_states)
+        hidden_states = self.activation(hidden_states)
         hidden_states = self.LayerNorm(inputs=hidden_states)
         seq_length = shape_list(tensor=hidden_states)[1]
         hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
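
The one-line call-site change here, and the identical change in each of the Intermediate layers below, follows from the ACT2FN rewrite: the entries are no longer tf.keras.layers.Activation layers, whose __call__ accepts an inputs= keyword, but plain callables that take the tensor positionally. A rough sketch of the difference, assuming TF >= 2.4 so that tf.keras.activations.gelu exists; the names are illustrative, not code from the commit.

import tensorflow as tf

hidden_states = tf.ones((2, 4), dtype=tf.float32)

# Before: the dict stored a Keras layer, so the keyword form was accepted.
act_layer = tf.keras.layers.Activation(tf.keras.activations.gelu)
out = act_layer(inputs=hidden_states)

# After: the dict stores a plain function (e.g. _gelu or tf.keras.activations.gelu),
# so it must be called positionally; act_fn(inputs=hidden_states) would raise
# TypeError: got an unexpected keyword argument 'inputs'.
act_fn = tf.keras.activations.gelu  # stand-in for ACT2FN["gelu"]
out = act_fn(hidden_states)

Module-level Activation layers are also built under the default float32 policy, so they would autocast half-precision inputs back to float32, which is presumably the other reason the commit drops them in favour of bare functions.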


@@ -428,7 +428,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states


@@ -327,7 +327,7 @@ class TFElectraIntermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states


@@ -709,7 +709,7 @@ class TFLongformerIntermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states


@@ -388,7 +388,7 @@ class TFMPNetIntermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states


@@ -448,7 +448,7 @@ class TFRobertaIntermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states


@@ -382,7 +382,7 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer):
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
         return hidden_states