Fix mixed precision in TF models (#9163)
* Fix Gelu precision
* Fix gelu_fast
* Naming
* Fix usage and apply style
* add TF gelu approximate version
* add TF gelu approximate version
* add TF gelu approximate version
* Apply style
* Fix albert
* Remove the usage of the Activation layer
parent 248fa1ae72
commit 3f290e6c84
@@ -15,9 +15,10 @@
 import math
 
 import tensorflow as tf
+from packaging import version
 
 
-def gelu(x):
+def _gelu(x):
     """
     Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
     initially created. For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
@@ -25,12 +26,12 @@ def gelu(x):
     https://arxiv.org/abs/1606.08415
     """
     x = tf.convert_to_tensor(x)
-    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
+    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
 
     return x * cdf
 
 
-def gelu_new(x):
+def _gelu_new(x):
     """
     Gaussian Error Linear Unit. This is a smoother version of the GELU. Original paper: https://arxiv.org/abs/1606.0841
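Not part of the diff, but a minimal sketch of why the cast is needed: under the Keras mixed_float16 policy the activation receives float16 tensors, while tf.math.sqrt(2.0) produces a float32 constant, and TensorFlow does not promote dtypes across binary ops, so the old line fails. The input values below are purely illustrative.

import tensorflow as tf

x = tf.constant([0.25, -1.5], dtype=tf.float16)  # what gelu sees under mixed_float16

# Old line: float16 / float32 raises an error because the dtypes must match
# cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))

# Fixed line: cast the constant to the input's dtype first
cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
print((x * cdf).dtype)  # float16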
@@ -56,21 +57,33 @@ def mish(x):
 
 def gelu_fast(x):
     x = tf.convert_to_tensor(x)
-    coeff1 = tf.cast(7978845608, x.dtype)
+    coeff1 = tf.cast(0.7978845608, x.dtype)
     coeff2 = tf.cast(0.044715, x.dtype)
 
     return 0.5 * x * (1.0 + tf.tanh(x * coeff2 * (1.0 + coeff1 * x * x)))
 
 
+if version.parse(tf.version.VERSION) >= version.parse("2.4"):
+
+    def approximate_gelu_wrap(x):
+        return tf.keras.activations.gelu(x, approximate=True)
+
+    gelu = tf.keras.activations.gelu
+    gelu_new = approximate_gelu_wrap
+else:
+    gelu = _gelu
+    gelu_new = _gelu_new
+
+
 ACT2FN = {
-    "gelu": tf.keras.layers.Activation(gelu),
+    "gelu": gelu,
     "relu": tf.keras.activations.relu,
     "swish": tf.keras.activations.swish,
     "silu": tf.keras.activations.swish,
-    "gelu_new": tf.keras.layers.Activation(gelu_new),
-    "mish": tf.keras.layers.Activation(mish),
+    "gelu_new": gelu_new,
+    "mish": mish,
     "tanh": tf.keras.activations.tanh,
-    "gelu_fast": tf.keras.layers.Activation(gelu_fast),
+    "gelu_fast": gelu_fast,
 }
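A hedged sketch (not part of the diff; the import path transformers.activations_tf is assumed from where these functions live) of how the rewritten module is meant to behave: on TF >= 2.4 the public names resolve to the built-in Keras gelu, otherwise to the local _gelu/_gelu_new fallbacks, and ACT2FN now maps to plain callables instead of tf.keras.layers.Activation wrappers, so each activation simply follows the dtype of the tensor it is given rather than a layer's own dtype policy.

import tensorflow as tf
from transformers.activations_tf import ACT2FN  # module being changed; path assumed

x = tf.constant([1.0, -2.0], dtype=tf.float16)

act = ACT2FN["gelu_new"]               # now a plain callable, not an Activation layer
print(act(x).dtype)                    # stays float16; a wrapping layer could cast per its policy
print(ACT2FN["gelu_fast"](x).numpy())  # called positionally; there is no `inputs=` keyword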
@@ -542,7 +542,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
 
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.activation(inputs=hidden_states)
+        hidden_states = self.activation(hidden_states)
         hidden_states = self.LayerNorm(inputs=hidden_states)
         seq_length = shape_list(tensor=hidden_states)[1]
         hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.embedding_size])
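A small illustration (not taken from the diff) of why the `inputs=` keyword has to go: self.activation now resolves to a plain function whose first parameter is named x, so the keyword call raises a TypeError while the positional call works. The helper below just re-declares the _gelu fallback from above as a stand-in.

import tensorflow as tf

def _gelu(x):  # stand-in for whatever self.activation resolves to from ACT2FN
    x = tf.convert_to_tensor(x)
    cdf = 0.5 * (1.0 + tf.math.erf(x / tf.cast(tf.sqrt(2.0), x.dtype)))
    return x * cdf

hidden_states = tf.ones((2, 4), dtype=tf.float16)
# _gelu(inputs=hidden_states)      # TypeError: unexpected keyword argument 'inputs'
print(_gelu(hidden_states).shape)  # (2, 4) -- positional call, as in the updated `call`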
@@ -428,7 +428,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
 
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
 
         return hidden_states
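For context, a hedged end-to-end sketch of the scenario these per-model changes target: running a TF checkpoint under the Keras mixed-precision policy. It assumes TF >= 2.4 and the transformers TFBertModel/BertTokenizer API; bert-base-uncased is only an example checkpoint and the snippet downloads weights on first run.

import tensorflow as tf
from transformers import BertTokenizer, TFBertModel

tf.keras.mixed_precision.set_global_policy("mixed_float16")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertModel.from_pretrained("bert-base-uncased")

batch = tokenizer(["mixed precision test"], return_tensors="tf")
outputs = model(batch)
# With dtype-aware activations the forward pass runs in float16 instead of
# failing inside gelu with a float16/float32 mismatch.
print(outputs.last_hidden_state.dtype)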
@@ -327,7 +327,7 @@ class TFElectraIntermediate(tf.keras.layers.Layer):
 
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
 
         return hidden_states
@@ -709,7 +709,7 @@ class TFLongformerIntermediate(tf.keras.layers.Layer):
 
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
 
         return hidden_states
@@ -388,7 +388,7 @@ class TFMPNetIntermediate(tf.keras.layers.Layer):
 
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
 
         return hidden_states
@@ -448,7 +448,7 @@ class TFRobertaIntermediate(tf.keras.layers.Layer):
 
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
 
         return hidden_states
@@ -382,7 +382,7 @@ class TF{{cookiecutter.camelcase_modelname}}Intermediate(tf.keras.layers.Layer):
 
     def call(self, hidden_states):
         hidden_states = self.dense(inputs=hidden_states)
-        hidden_states = self.intermediate_act_fn(inputs=hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
 
         return hidden_states
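To tie the repeated per-model edit together, here is a minimal stand-alone layer in the same shape as these Intermediate classes (the class name, sizes, and defaults are illustrative, not the library's): it looks the configured activation string up in ACT2FN and calls the result positionally, which is exactly the calling convention the diff switches every model to.

import tensorflow as tf
from transformers.activations_tf import ACT2FN  # path assumed, as above

class ToyIntermediate(tf.keras.layers.Layer):
    """Illustrative stand-in for the TF*Intermediate layers touched by the diff."""

    def __init__(self, hidden_act="gelu", intermediate_size=8, **kwargs):
        super().__init__(**kwargs)
        self.dense = tf.keras.layers.Dense(intermediate_size)
        self.intermediate_act_fn = ACT2FN[hidden_act]  # plain callable, not a layer

    def call(self, hidden_states):
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)  # positional call
        return hidden_states

layer = ToyIntermediate()
print(layer(tf.ones((1, 4))).shape)  # (1, 8)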