Replace swish with silu (#8166)

* Replace swish with silu

* revert nn.silu to nn.swish due to older version

* simplify optimized silu conditional and fix format

* Update activations.py

* Update activations_tf.py

* Update modeling_flax_utils.py

* Update modeling_openai.py

* add swish testcase

* add pytorch swish testcase

* Add more robust python version check

* more formatting fixes

Co-authored-by: TFUsers <TFUsers@gmail.com>
This commit is contained in:
TFUsers 2020-10-30 12:09:10 -07:00 committed by GitHub
parent cdc48ce92d
commit 00112c3539
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 56 additions and 37 deletions

View File

@ -2,6 +2,7 @@ import math
import torch
import torch.nn.functional as F
from packaging import version
from .utils import logging
@ -9,29 +10,25 @@ from .utils import logging
logger = logging.get_logger(__name__)
def swish(x):
return x * torch.sigmoid(x)
def _gelu_python(x):
"""
Original Implementation of the gelu activation function in Google Bert repo when initially created. For
information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
Original Implementation of the GELU activation function in Google BERT repo when initially created. For
information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in
torch.nn.functional Also see https://arxiv.org/abs/1606.08415
torch.nn.functional Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
def gelu_new(x):
"""
Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT). Also see
https://arxiv.org/abs/1606.08415
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
if torch.__version__ < "1.4.0":
if version.parse(torch.__version__) < version.parse("1.4"):
gelu = _gelu_python
else:
gelu = F.gelu
@ -41,6 +38,23 @@ def gelu_fast(x):
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * (1.0 + 0.044715 * x * x)))
def _silu_python(x):
"""
See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
later.
"""
return x * torch.sigmoid(x)
if version.parse(torch.__version__) < version.parse("1.7"):
silu = _silu_python
else:
silu = F.silu
def mish(x):
return x * torch.tanh(torch.nn.functional.softplus(x))
@ -51,7 +65,8 @@ def linear_act(x):
ACT2FN = {
"relu": F.relu,
"swish": swish,
"silu": silu,
"swish": silu,
"gelu": gelu,
"tanh": torch.tanh,
"gelu_new": gelu_new,

View File

@ -52,6 +52,7 @@ ACT2FN = {
"gelu": tf.keras.layers.Activation(gelu),
"relu": tf.keras.activations.relu,
"swish": tf.keras.activations.swish,
"silu": tf.keras.activations.swish,
"gelu_new": tf.keras.layers.Activation(gelu_new),
"mish": tf.keras.layers.Activation(mish),
"tanh": tf.keras.activations.tanh,

View File

@ -61,7 +61,7 @@ class AlbertConfig(PretrainedConfig):
The number of inner repetition of attention and ffn.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu_new"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0):

View File

@ -59,7 +59,7 @@ class BartConfig(PretrainedConfig):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):

View File

@ -74,7 +74,7 @@ class BertConfig(PretrainedConfig):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):

View File

@ -40,7 +40,7 @@ class BertGenerationConfig(PretrainedConfig):
Dimensionality of the "intermediate" (often called feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):

View File

@ -56,7 +56,7 @@ class BlenderbotConfig(BartConfig):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):

View File

@ -52,7 +52,7 @@ class DebertaConfig(PretrainedConfig):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"`, :obj:`"gelu"`, :obj:`"tanh"`, :obj:`"gelu_fast"`,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"`, :obj:`"gelu"`, :obj:`"tanh"`, :obj:`"gelu_fast"`,
:obj:`"mish"`, :obj:`"linear"`, :obj:`"sigmoid"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.

View File

@ -66,7 +66,7 @@ class DistilBertConfig(PretrainedConfig):
The dropout ratio for the attention probabilities.
activation (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
initializer_range (:obj:`float`, `optional`, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
qa_dropout (:obj:`float`, `optional`, defaults to 0.1):

View File

@ -55,7 +55,7 @@ class DPRConfig(PretrainedConfig):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):

View File

@ -60,7 +60,7 @@ class ElectraConfig(PretrainedConfig):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):

View File

@ -71,7 +71,7 @@ class FSMTConfig(PretrainedConfig):
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
activation_function (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"relu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):

View File

@ -66,7 +66,7 @@ class FunnelConfig(PretrainedConfig):
Inner dimension in the feed-forward blocks.
hidden_act (:obj:`str` or :obj:`callable`, `optional`, defaults to :obj:`"gelu_new"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.1):

View File

@ -60,7 +60,7 @@ class GPT2Config(PretrainedConfig):
n_inner (:obj:`int`, `optional`, defaults to None):
Dimensionality of the inner feed-forward layers. :obj:`None` will set it to 4 times n_embd
activation_function (:obj:`str`, `optional`, defaults to :obj:`"gelu"`):
Activation function, to be selected in the list :obj:`["relu", "swish", "gelu", "tanh", "gelu_new"]`.
Activation function, to be selected in the list :obj:`["relu", "silu", "gelu", "tanh", "gelu_new"]`.
resid_pdrop (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (:obj:`int`, `optional`, defaults to 0.1):

View File

@ -52,7 +52,7 @@ class LayoutLMConfig(BertConfig):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):

View File

@ -55,7 +55,7 @@ class LxmertConfig(PretrainedConfig):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):

View File

@ -50,7 +50,7 @@ class MarianConfig(BartConfig):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in decoder.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):

View File

@ -55,7 +55,7 @@ class MBartConfig(BartConfig):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in decoder.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):

View File

@ -48,7 +48,7 @@ class MobileBertConfig(PretrainedConfig):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"relu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.0):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):

View File

@ -54,7 +54,7 @@ class OpenAIGPTConfig(PretrainedConfig):
Number of attention heads for each attention layer in the Transformer encoder.
afn (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
resid_pdrop (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
embd_pdrop (:obj:`int`, `optional`, defaults to 0.1):

View File

@ -94,7 +94,7 @@ class PegasusConfig(BartConfig):
Dimensionality of the "intermediate" (i.e., feed-forward) layer in decoder.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
dropout (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_dropout (:obj:`float`, `optional`, defaults to 0.0):

View File

@ -39,7 +39,7 @@ class ProphetNetConfig(PretrainedConfig):
The dropout ratio for activations inside the fully connected layer.
activation_function (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
vocab_size (:obj:`int`, `optional`, defaults to 30522):
Vocabulary size of the ProphetNET model. Defines the number of different tokens that can be represented by
the :obj:`inputs_ids` passed when calling :class:`~transformers.ProphetNetModel`.

View File

@ -80,7 +80,7 @@ class ReformerConfig(PretrainedConfig):
:obj:`None` to ensure fully random rotations in local sensitive hashing scheme.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"relu"`):
The non-linear activation function (function or string) in the feed forward layer in the residual attention
block. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
block. If string, :obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.05):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
hidden_size (:obj:`int`, `optional`, defaults to 256):

View File

@ -49,7 +49,7 @@ class RetriBertConfig(PretrainedConfig):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`function`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):

View File

@ -50,7 +50,7 @@ class SqueezeBertConfig(PretrainedConfig):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
hidden_act (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the encoder and pooler. If string,
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"gelu"`, :obj:`"relu"`, :obj:`"silu"` and :obj:`"gelu_new"` are supported.
hidden_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob (:obj:`float`, `optional`, defaults to 0.1):

View File

@ -54,7 +54,7 @@ class XLNetConfig(PretrainedConfig):
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
ff_activation (:obj:`str` or :obj:`Callable`, `optional`, defaults to :obj:`"gelu"`):
The non-linear activation function (function or string) in the If string, :obj:`"gelu"`, :obj:`"relu"`,
:obj:`"swish"` and :obj:`"gelu_new"` are supported.
:obj:`"silu"` and :obj:`"gelu_new"` are supported.
untie_r (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to untie relative position biases
attn_type (:obj:`str`, `optional`, defaults to :obj:`"bi"`):

View File

@ -53,6 +53,7 @@ def gelu(x):
ACT2FN = {
"gelu": nn.gelu,
"relu": nn.relu,
"silu": nn.swish,
"swish": nn.swish,
"gelu_new": gelu,
}

View File

@ -27,7 +27,7 @@ import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from .activations import gelu_new, swish
from .activations import gelu_new, silu
from .configuration_openai import OpenAIGPTConfig
from .file_utils import (
ModelOutput,
@ -139,7 +139,7 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
return model
ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu_new}
ACT_FNS = {"relu": nn.ReLU, "silu": silu, "gelu": gelu_new, "swish": silu}
class Attention(nn.Module):

View File

@ -20,6 +20,7 @@ class TestActivations(unittest.TestCase):
def test_get_activation(self):
get_activation("swish")
get_activation("silu")
get_activation("relu")
get_activation("tanh")
get_activation("gelu_new")

View File

@ -12,6 +12,7 @@ if is_tf_available():
class TestTFActivations(unittest.TestCase):
def test_get_activation(self):
get_tf_activation("swish")
get_tf_activation("silu")
get_tf_activation("gelu")
get_tf_activation("relu")
get_tf_activation("tanh")