get_activation('relu') provides a simple mapping from strings i… (#2807)

* activations.py contains a mapping from string to activation function
* resolves some `gelu` vs `gelu_new` ambiguity
Sam Shleifer authored on 2020-02-13 08:28:33 -05:00; committed via GitHub
parent f54a5bd37f
commit ef74b0f07a
9 changed files with 94 additions and 68 deletions
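
For context on the API this commit adds, here is a minimal usage sketch of get_activation (the import path matches the new tests below; the input tensor is just illustrative):

    import torch
    from transformers.activations import get_activation

    gelu = get_activation("gelu")   # torch.nn.functional.gelu when available, else the Python fallback
    relu = get_activation("relu")   # torch.nn.functional.relu

    x = torch.tensor([-1.0, 0.0, 1.0])  # illustrative input
    print(relu(x), gelu(x))

    # Unknown names raise a KeyError that lists the supported strings:
    # get_activation("bogus")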

activations.py (new file)

@@ -0,0 +1,48 @@
+import math
+
+import torch
+import torch.nn.functional as F
+
+
+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+def _gelu_python(x):
+    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
+        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
+        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+        This is now written in C in torch.nn.functional
+        Also see https://arxiv.org/abs/1606.08415
+    """
+    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+gelu = getattr(F, "gelu", _gelu_python)
+
+
+def gelu_new(x):
+    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
+        Also see https://arxiv.org/abs/1606.08415
+    """
+    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+
+ACT2FN = {
+    "relu": F.relu,
+    "swish": swish,
+    "gelu": gelu,
+    "tanh": F.tanh,
+    "gelu_new": gelu_new,
+}
+
+
+def get_activation(activation_string):
+    if activation_string in ACT2FN:
+        return ACT2FN[activation_string]
+    else:
+        raise KeyError(
+            "function {} not found in ACT2FN mapping {} or torch.nn.functional".format(
+                activation_string, list(ACT2FN.keys())
+            )
+        )
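
The line gelu = getattr(F, "gelu", _gelu_python) prefers torch's native C implementation and only falls back to the pure-Python erf version on torch builds that lack F.gelu. A quick sketch comparing the erf-based gelu with the tanh approximation gelu_new; it prints the maximum absolute difference rather than asserting a specific value, since only "close but not identical" matters here:

    import torch
    from transformers.activations import get_activation, gelu_new

    x = torch.linspace(-5, 5, steps=101)
    erf_gelu = get_activation("gelu")   # exact erf-based gelu
    approx = gelu_new(x)                # tanh approximation used by GPT-2 / XLNet below

    print("max abs difference:", (erf_gelu(x) - approx).abs().max().item())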

modeling_bert.py

@@ -24,6 +24,7 @@ import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 
+from .activations import gelu, gelu_new, swish
 from .configuration_bert import BertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PreTrainedModel, prune_linear_layer
@@ -129,26 +130,6 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
     return model
 
 
-def gelu(x):
-    """ Original Implementation of the gelu activation function in Google Bert repo when initially created.
-        For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
-        0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-        Also see https://arxiv.org/abs/1606.08415
-    """
-    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-
-def gelu_new(x):
-    """ Implementation of the gelu activation function currently in Google Bert repo (identical to OpenAI GPT).
-        Also see https://arxiv.org/abs/1606.08415
-    """
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-
-
-def swish(x):
-    return x * torch.sigmoid(x)
 
 
 def mish(x):
     return x * torch.tanh(nn.functional.softplus(x))

modeling_distilbert.py

@@ -27,6 +27,7 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
+from .activations import gelu
 from .configuration_distilbert import DistilBertConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PreTrainedModel, prune_linear_layer
@@ -47,8 +48,6 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
 # UTILS AND BUILDING BLOCKS OF THE ARCHITECTURE #
 
 
-def gelu(x):
-    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
 
 def create_sinusoidal_embeddings(n_pos, dim, out):

modeling_gpt2.py

@@ -24,6 +24,7 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
+from .activations import gelu_new
 from .configuration_gpt2 import GPT2Config
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
@@ -95,10 +96,6 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
     return model
 
 
-def gelu(x):
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
 
 
 class Attention(nn.Module):
     def __init__(self, nx, n_ctx, config, scale=False):
         super().__init__()
@@ -206,7 +203,7 @@ class MLP(nn.Module):
         nx = config.n_embd
         self.c_fc = Conv1D(n_state, nx)
         self.c_proj = Conv1D(nx, n_state)
-        self.act = gelu
+        self.act = gelu_new
         self.dropout = nn.Dropout(config.resid_pdrop)
 
     def forward(self, x):

modeling_openai.py

@@ -25,6 +25,7 @@ import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
+from .activations import gelu_new, swish
 from .configuration_openai import OpenAIGPTConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import Conv1D, PreTrainedModel, SequenceSummary, prune_conv1d_layer
@@ -114,15 +115,7 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
     return model
 
 
-def gelu(x):
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-
-
-def swish(x):
-    return x * torch.sigmoid(x)
-
-
-ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu}
+ACT_FNS = {"relu": nn.ReLU, "swish": swish, "gelu": gelu_new}
 
 
 class Attention(nn.Module):

modeling_utils.py

@@ -18,12 +18,14 @@
 import logging
 import os
+import typing
 
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from torch.nn import functional as F
 
+from .activations import get_activation
 from .configuration_utils import PretrainedConfig
 from .file_utils import (
     DUMMY_INPUTS,
@@ -1378,15 +1380,15 @@ class SequenceSummary(nn.Module):
             - 'attn' => Not implemented now, use multi-head attention
         summary_use_proj: Add a projection after the vector extraction
         summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
-        summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
+        summary_activation: 'tanh' or another string => add an activation to the output, Other => no activation. Default
         summary_first_dropout: Add a dropout before the projection and activation
         summary_last_dropout: Add a dropout after the projection and activation
     """
 
-    def __init__(self, config):
+    def __init__(self, config: PretrainedConfig):
         super().__init__()
 
-        self.summary_type = config.summary_type if hasattr(config, "summary_type") else "last"
+        self.summary_type = getattr(config, "summary_type", "last")
         if self.summary_type == "attn":
             # We should use a standard multi-head attention module with absolute positional embedding for that.
             # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
@@ -1401,9 +1403,10 @@ class SequenceSummary(nn.Module):
                 num_classes = config.hidden_size
             self.summary = nn.Linear(config.hidden_size, num_classes)
 
-        self.activation = Identity()
-        if hasattr(config, "summary_activation") and config.summary_activation == "tanh":
-            self.activation = nn.Tanh()
+        activation_string = getattr(config, "summary_activation", None)
+        self.activation = (
+            get_activation(activation_string) if activation_string else Identity()
+        )  # type: typing.Callable
 
         self.first_dropout = Identity()
         if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:

modeling_xlm.py

@@ -26,6 +26,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 from torch.nn import functional as F
 
+from .activations import gelu
 from .configuration_xlm import XLMConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PreTrainedModel, SequenceSummary, SQuADHead, prune_linear_layer
@@ -55,17 +56,6 @@ def create_sinusoidal_embeddings(n_pos, dim, out):
     out.requires_grad = False
 
 
-def gelu(x):
-    """
-    GELU activation
-    https://arxiv.org/abs/1606.08415
-    https://github.com/huggingface/pytorch-openai-transformer-lm/blob/master/model_pytorch.py#L14
-    https://github.com/huggingface/transformers/blob/master/modeling.py
-    """
-    # return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))
 
 
 def get_masks(slen, lengths, causal, padding_mask=None):
     """
     Generate hidden states mask, and optionally an attention mask.

modeling_xlnet.py

@@ -18,13 +18,13 @@
 import logging
-import math
 
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
 from torch.nn import functional as F
 
+from .activations import gelu_new, swish
 from .configuration_xlnet import XLNetConfig
 from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
 from .modeling_utils import PoolerAnswerClass, PoolerEndLogits, PoolerStartLogits, PreTrainedModel, SequenceSummary
@@ -183,20 +183,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
     return model
 
 
-def gelu(x):
-    """ Implementation of the gelu activation function.
-        XLNet is using OpenAI GPT's gelu (not exactly the same as BERT)
-        Also see https://arxiv.org/abs/1606.08415
-    """
-    cdf = 0.5 * (1.0 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
-    return x * cdf
-
-
-def swish(x):
-    return x * torch.sigmoid(x)
-
-
-ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+ACT2FN = {"gelu": gelu_new, "relu": torch.nn.functional.relu, "swish": swish}
 
 
 XLNetLayerNorm = nn.LayerNorm

tests/test_activations.py (new file)

@@ -0,0 +1,28 @@
+import unittest
+
+from transformers import is_torch_available
+
+from .utils import require_torch
+
+
+if is_torch_available():
+    from transformers.activations import _gelu_python, get_activation, gelu_new
+    import torch
+
+
+@require_torch
+class TestActivations(unittest.TestCase):
+    def test_gelu_versions(self):
+        x = torch.Tensor([-100, -1, -0.1, 0, 0.1, 1.0, 100])
+        torch_builtin = get_activation("gelu")
+        self.assertTrue(torch.eq(_gelu_python(x), torch_builtin(x)).all().item())
+        self.assertFalse(torch.eq(_gelu_python(x), gelu_new(x)).all().item())
+
+    def test_get_activation(self):
+        get_activation("swish")
+        get_activation("relu")
+        get_activation("tanh")
+        with self.assertRaises(KeyError):
+            get_activation("bogus")
+        with self.assertRaises(KeyError):
+            get_activation(None)
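
The new tests can presumably be run on their own with the repository's usual pytest entry point, for example:

    python -m pytest -sv tests/test_activations.py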