get_activation('relu') provides a simple mapping from strings i… (#2807)
* activations.py contains a mapping from string to activation function
* resolves some `gelu` vs `gelu_new` ambiguity
This commit is contained in:
parent f54a5bd37f
commit ef74b0f07a
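
In a nutshell: get_activation resolves a config string to a callable. A minimal usage sketch (illustrative, not part of the commit):

import torch
from transformers.activations import get_activation

act = get_activation("gelu")       # F.gelu where torch provides it, _gelu_python otherwise
out = act(torch.randn(2, 4))

get_activation("gelu_new")         # the tanh approximation used by OpenAI GPT / GPT-2
try:
    get_activation("bogus")
except KeyError:
    pass                           # unknown names now fail loudly instead of silently
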
activations.py
@@ -0,0 +1,48 @@
+import math
+
+import torch
+import torch.nn.functional as F
+
+
+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+def _gelu_python(x):
+    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
+        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+        This is now written in C in torch.nn.functional
+        Also see https://arxiv.org/abs/1606.08415
+    """
+    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+gelu = getattr(F, "gelu", _gelu_python)
+
+
+def gelu_new(x):
+    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
+        Also see https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+
+ACT2FN = {
+    "relu": F.relu,
+    "swish": swish,
+    "gelu": gelu,
+    "tanh": F.tanh,
+    "gelu_new": gelu_new,
+}
+
+
+def get_activation(activation_string):
+    if activation_string in ACT2FN:
+        return ACT2FN[activation_string]
+    else:
+        raise KeyError(
+            "function {} not found in ACT2FN mapping {} or torch.nn.functional".format(
+                activation_string, list(ACT2FN.keys())
+            )
+        )
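
A detail worth calling out in the new module: gelu = getattr(F, "gelu", _gelu_python) falls back to the pure-Python erf implementation on torch builds that lack a native torch.nn.functional.gelu (that is what the "This is now written in C" docstring note suggests). A quick, illustrative way to check which implementation was picked up:

import torch.nn.functional as F
from transformers import activations

print(activations.gelu is getattr(F, "gelu", None))    # True when the native C kernel is used
print(activations.gelu is activations._gelu_python)    # True only on the fallback path
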
modeling_bert.py
@@ -24,6 +24,7 @@ import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 
+from .activations import gelu, gelu_new, swish
 from .configuration_bert import BertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PreTrainedModel, prune_linear_layer
@@ -129,26 +130,6 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
     return model
 
 
-def gelu(x):
-    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
-        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
-        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-        Also see https://arxiv.org/abs/1606.08415
-    """
-    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-
-def gelu_new(x):
-    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
-        Also see https://arxiv.org/abs/1606.08415
-    """
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-
-
-def swish(x):
-    return x * torch.sigmoid(x)
-
-
 def mish(x):
     return x * torch.tanh(nn.functional.softplus(x))
 
modeling_distilbert.py
@@ -27,6 +27,7 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
+from .activations import gelu
 from .configuration_distilbert import DistilBertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PreTrainedModel, prune_linear_layer
@@ -47,8 +48,6 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 
 
 # UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
-def gelu(x):
-    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
 
 def create_sinusoidal_embeddings(n_pos, dim, out):
modeling_gpt2.py
@@ -24,6 +24,7 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
+from .activations import gelu_new
 from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
@@ -95,10 +96,6 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
     return model
 
 
-def gelu(x):
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-
-
 class Attention(nn.Module):
     def __init__(self, nx, n_ctx, config, scale=False):
         super().__init__()
@@ -206,7 +203,7 @@ class MLP(nn.Module):
         nx = config.n_embd
         self.c_fc = Conv1D(n_state, nx)
         self.c_proj = Conv1D(nx, n_state)
-        self.act = gelu
+        self.act = gelu_new
         self.dropout = nn.Dropout(config.resid_pdrop)
 
     def forward(self, x):
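
The self.act = gelu -> self.act = gelu_new swap above is behavior-preserving: the deleted module-local gelu was already the tanh form, which is exactly what activations.gelu_new computes. The same substitution appears in the modeling_openai.py and modeling_xlnet.py hunks below. A minimal sanity check (illustrative):

import math
import torch
from transformers.activations import gelu_new

def old_gpt2_gelu(x):  # the deleted local definition, reproduced verbatim
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

x = torch.randn(16)
assert torch.equal(old_gpt2_gelu(x), gelu_new(x))  # identical expression, identical results
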
modeling_openai.py
@@ -25,6 +25,7 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
+from .activations import gelu_new, swish
 from .configuration_openai import OpenAIGPTConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
@@ -114,15 +115,7 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
     return model
 
 
-def gelu(x):
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-
-
-def swish(x):
-    return x * torch.sigmoid(x)
-
-
-ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
+ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu_new}
 
 
 class Attention(nn.Module):
modeling_utils.py
@@ -18,12 +18,14 @@
 
 import logging
 import os
+import typing
 
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 
+from .activations import get_activation
 from .configuration_utils import PretrainedConfig
 from .file_utils import (
     DUMMY_INPUTS,
@@ -1378,15 +1380,15 @@ class SequenceSummary(nn.Module):
             - 'attn' => Not implemented now, use multi-head attention
         summary_use_proj: Add a projection after the vector extraction
         summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
-        summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
+        summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default
         summary_first_dropout: Add a dropout before the projection and activation
        summary_last_dropout: Add a dropout after the projection and activation
     """
 
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
 
-        self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
+        self.summary_type = getattr(config, "summary_type", "last")
         if self.summary_type == "attn":
             # We should use a standard multi-head attention module with absolute positional embedding for that.
             # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
@@ -1401,9 +1403,10 @@ class SequenceSummary(nn.Module):
             num_classes = config.hidden_size
             self.summary = nn.Linear(config.hidden_size, num_classes)
 
-        self.activation = Identity()
-        if hasattr(config, "summary_activation") and config.summary_activation == "tanh":
-            self.activation = nn.Tanh()
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation = (
+            get_activation(activation_string) if activation_string else Identity()
+        )  # type: typing.Callable
 
         self.first_dropout = Identity()
         if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
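
The practical effect of the SequenceSummary change: summary_activation is no longer limited to the literal string "tanh". Any ACT2FN key now resolves, and misspelled names raise instead of silently becoming Identity(). A short sketch (the "tanhh" typo is hypothetical):

from transformers.activations import get_activation

for name in ("tanh", "gelu", "gelu_new", "relu", "swish"):
    act = get_activation(name)     # each returns the callable registered in ACT2FN

try:
    get_activation("tanhh")
except KeyError as err:
    print(err)                     # the error message lists the valid keys
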
modeling_xlm.py
@@ -26,6 +26,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 from torch.nn import functional as F
 
+from .activations import gelu
 from .configuration_xlm import XLMConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead, prune_linear_layer
@@ -55,17 +56,6 @@ def create_sinusoidal_embeddings(n_pos, dim, out):
     out.requires_grad = False
 
 
-def gelu(x):
-    """
-    GELU activation
-    https://arxiv.org/abs/1606.08415
-    https://github.com/huggingface/pytorch-openai-transformer-lm/blob/master/model_pytorch.py#L14
-    https://github.com/huggingface/transformers/blob/master/modeling.py
-    """
-    # return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-
 def get_masks(slen, lengths, causal, padding_mask=None):
     """
     Generate hidden states mask, and optionally an attention mask.
modeling_xlnet.py
@@ -18,13 +18,13 @@
 
 
 import logging
-import math
 
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 from torch.nn import functional as F
 
+from .activations import gelu_new, swish
 from .configuration_xlnet import XLNetConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
@@ -183,20 +183,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
     return model
 
 
-def gelu(x):
-    """ Implementation of the gelu activation function.
-        XLNet is using OpenAI GPT's gelu (not exactly the same as BERT)
-        Also see https://arxiv.org/abs/1606.08415
-    """
-    cdf = 0.5 * (1.0 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-    return x * cdf
-
-
-def swish(x):
-    return x * torch.sigmoid(x)
-
-
-ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+ACT2FN = {"gelu": gelu_new, "relu": torch.nn.functional.relu, "swish": swish}
 
 
 XLNetLayerNorm = nn.LayerNorm
test_activations.py
@@ -0,0 +1,28 @@
+import unittest
+
+from transformers import is_torch_available
+
+from .utils import require_torch
+
+
+if is_torch_available():
+    from transformers.activations import _gelu_python, get_activation, gelu_new
+    import torch
+
+
+@require_torch
+class TestActivations(unittest.TestCase):
+    def test_gelu_versions(self):
+        x = torch.Tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
+        torch_builtin = get_activation("gelu")
+        self.assertTrue(torch.eq(_gelu_python(x), torch_builtin(x)).all().item())
+        self.assertFalse(torch.eq(_gelu_python(x), gelu_new(x)).all().item())
+
+    def test_get_activation(self):
+        get_activation("swish")
+        get_activation("relu")
+        get_activation("tanh")
+        with self.assertRaises(KeyError):
+            get_activation("bogus")
+        with self.assertRaises(KeyError):
+            get_activation(None)
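
For intuition about the assertFalse in test_gelu_versions: the exact erf form and the tanh approximation track each other closely but are not bitwise equal. A sketch to eyeball the gap (the printed magnitude depends on your torch build and is not a documented constant):

import torch
from transformers.activations import _gelu_python, gelu_new

x = torch.linspace(-5, 5, steps=1001)
print((_gelu_python(x) - gelu_new(x)).abs().max())  # small but nonzero
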