# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import logging
import os

import tensorflow as tf

from .configuration_utils import PretrainedConfig
from .file_utils import cached_path, WEIGHTS_NAME, TF_WEIGHTS_NAME, TF2_WEIGHTS_NAME

logger = logging.getLogger(__name__)

class TFPreTrainedModel(tf.keras.Model):
    r""" Base class for all TF models.

        :class:`~pytorch_transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
        as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention layers.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~pytorch_transformers.PretrainedConfig` to use as configuration class for this model architecture.
            - ``pretrained_model_archive_map``: a python ``dict`` with `short-cut-names` (string) as keys and `url` (string) of associated pretrained weights as values.
            - ``load_pt_weights``: a python ``method`` for loading a PyTorch checkpoint in a TF 2.0 model, taking as arguments:

                - ``model``: an instance of the relevant subclass of :class:`~pytorch_transformers.TFPreTrainedModel`,
                - ``config``: an instance of the relevant subclass of :class:`~pytorch_transformers.PretrainedConfig`,
                - ``path``: a path (string) to the PyTorch checkpoint.

            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    pretrained_model_archive_map = {}
    load_pt_weights = lambda model, config, path: None
    base_model_prefix = ""

    def __init__(self, config, *inputs, **kwargs):
        super(TFPreTrainedModel, self).__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        # Save config in model
        self.config = config

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Variable from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end.
            Reducing the size will remove vectors from the end.

        Args:
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end.
                Reducing the size will remove vectors from the end.
                If not provided or None: return the provided token Embedding Module.
        Return: ``tf.Variable``
            Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None
        """
        # Not implemented for TF 2.0 yet; the PyTorch reference logic is kept
        # below as a guide for the port:
        #
        # if new_num_tokens is None:
        #     return old_embeddings
        #
        # old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        # if old_num_tokens == new_num_tokens:
        #     return old_embeddings
        #
        # # Build new embeddings
        # new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
        # new_embeddings.to(old_embeddings.weight.device)
        #
        # # initialize all new embeddings (in particular added tokens)
        # self._init_weights(new_embeddings)
        #
        # # Copy word embeddings from the previous weights
        # num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        # new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[:num_tokens_to_copy, :]
        #
        # return new_embeddings
        raise NotImplementedError

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
            Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns a pointer to the input tokens ``tf.Variable`` Module of the model.

        Return: ``tf.Variable``
            Pointer to the input tokens Embeddings Module of the model
        """
        raise NotImplementedError

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

        Arguments:

            heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the :func:`~pytorch_transformers.TFPreTrainedModel.from_pretrained` class method.
        """
        assert os.path.isdir(save_directory), "Saving path should be a directory where the model and configuration can be saved"

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        Dropout is deactivated by default (the model is called with ``training=False``).
        To activate dropout during training, call the model with ``training=True``.

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:

                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
                - a path to a `directory` containing model weights saved using :func:`~pytorch_transformers.TFPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) instance of a class derived from :class:`~pytorch_transformers.PretrainedConfig`:
                Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:

                - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
                - the model was saved using :func:`~pytorch_transformers.TFPreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
                - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.

            from_pt: (`optional`) boolean, default False:
                Load the model weights from a PyTorch state_dict save file (see docstring of pretrained_model_name_or_path argument).

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if they exist.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it has been loaded) and to initiate the model (e.g. ``output_attention=True``). Behaves differently depending on whether a `config` is provided or automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~pytorch_transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a PyTorch checkpoint file instead of a TF 2.0 model (slower)
            config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            model = BertModel.from_pretrained('./pt_model/pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop('config', None)
        cache_dir = kwargs.pop('cache_dir', None)
        from_pt = kwargs.pop('from_pt', False)
        force_download = kwargs.pop('force_download', False)
        proxies = kwargs.pop('proxies', None)

        # Load config
        if config is None:
            config, model_kwargs = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, *model_args,
                cache_dir=cache_dir, return_unused_kwargs=True,
                force_download=force_download,
                **kwargs
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
            elif os.path.isdir(pretrained_model_name_or_path):
                if os.path.isfile(os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)):
                    # Load from a TF 2.0 checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                elif from_pt and os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
                    # Load from a PyTorch checkpoint
                    archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                else:
                    raise EnvironmentError("Error: no file named {} found in directory {} (or `from_pt` is set to False)".format(
                        [WEIGHTS_NAME, TF2_WEIGHTS_NAME],
                        pretrained_model_name_or_path))
            elif os.path.isfile(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            else:
                raise EnvironmentError("Error: file {} not found".format(pretrained_model_name_or_path))

            # redirect to the cache, if necessary
            try:
                resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
            except EnvironmentError as e:
                if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
                    logger.error(
                        "Couldn't reach server at '{}' to download pretrained weights.".format(
                            archive_file))
                else:
                    logger.error(
                        "Model name '{}' was not found in model name list ({}). "
                        "We assumed '{}' was a path or url but couldn't find any file "
                        "associated to this path or url.".format(
                            pretrained_model_name_or_path,
                            ', '.join(cls.pretrained_model_archive_map.keys()),
                            archive_file))
                raise e
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(
                    archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            # Load from a PyTorch checkpoint
            # (pass config as well, to match the `load_pt_weights(model, config, path)` signature)
            return cls.load_pt_weights(model, config, resolved_archive_file)

        inputs = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
        ret = model(inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allows us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        model.load_weights(resolved_archive_file, by_name=True)

        ret = model(inputs, training=False)  # Make sure restore ops are run

        return model

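# A minimal usage sketch (illustrative only, not part of the library API):
# round-trip a model through `save_pretrained`/`from_pretrained`. The helper
# name `_example_save_and_reload` and its `save_dir` default are hypothetical.
def _example_save_and_reload(model, save_dir="./my_model_directory/"):
    # `model` is assumed to be an instance of a concrete TFPreTrainedModel
    # subclass (e.g. a BERT model); `save_dir` must already exist.
    model.save_pretrained(save_dir)  # writes config.json + TF2_WEIGHTS_NAME
    return model.__class__.from_pretrained(save_dir)
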
class TFConv1D(tf.keras.layers.Layer):
    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        """ TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2).
            Basically works like a linear layer but the weights are transposed.
        """
        super(TFConv1D, self).__init__(**kwargs)
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        self.weight = self.add_weight(
            "weight",
            shape=[self.nx, self.nf],
            initializer=get_initializer(self.initializer_range))
        self.bias = self.add_weight(
            "bias",
            shape=[1, self.nf],
            initializer=tf.zeros_initializer())

    def call(self, x):
        bz, sl = shape_list(x)[:2]

        x = tf.reshape(x, [-1, self.nx])
        x = tf.matmul(x, self.weight) + self.bias

        x = tf.reshape(x, [bz, sl, self.nf])

        return x

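# A small sketch of how TFConv1D is used (hypothetical helper, for illustration):
# the layer maps the last dimension from `nx` to `nf` while keeping [batch, seq]
# intact, i.e. it behaves like a dense layer with transposed weight storage.
def _example_tf_conv1d():
    layer = TFConv1D(nf=12, nx=4)
    x = tf.random.uniform([2, 5, 4])  # [batch, seq_len, nx]
    y = layer(x)                      # [batch, seq_len, nf] == [2, 5, 12]
    return y
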
class TFSharedEmbeddings(tf.keras.layers.Layer):
    """Construct shared token embeddings.
    """
    def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
        super(TFSharedEmbeddings, self).__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size**-0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        """Build shared word embedding layer
        Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight",
            shape=[self.vocab_size, self.hidden_size],
            initializer=get_initializer(self.initializer_range))
        super(TFSharedEmbeddings, self).build(input_shape)

    def call(self, inputs, mode="embedding"):
        """Get token embeddings of inputs.
        Args:
            inputs: in "embedding" mode, an int tensor of token ids with shape [batch_size, length]; in "linear" mode, a float32 tensor with shape [..., hidden_size].
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", output embedding tensor, float32 with
                shape [batch_size, length, embedding_size]; (2) mode == "linear", output
                linear tensor, float32 with shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
        Args:
            inputs: A float32 tensor with shape [..., hidden_size]
        Returns:
            float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]

        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])

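# A short sketch of the two TFSharedEmbeddings modes (hypothetical helper, for
# illustration): the same weight matrix embeds ids and, transposed, produces
# output logits -- the usual input/output embedding tying.
def _example_shared_embeddings():
    emb = TFSharedEmbeddings(vocab_size=100, hidden_size=16)
    ids = tf.constant([[1, 2, 3]])        # [batch, length]
    hidden = emb(ids, mode="embedding")   # [1, 3, 16]
    logits = emb(hidden, mode="linear")   # [1, 3, 100], same weight transposed
    return hidden, logits
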
class TFSequenceSummary(tf.keras.layers.Layer):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
            summary_activation: 'tanh' => add a tanh activation to the output, other values => no activation. Default: no activation.
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
    """
    def __init__(self, config, initializer_range=0.02, **kwargs):
        super(TFSequenceSummary, self).__init__(**kwargs)

        self.summary_type = config.summary_type if hasattr(config, 'summary_type') else 'last'
        if self.summary_type == 'attn':
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = None
        if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
            if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(num_classes,
                                                 kernel_initializer=get_initializer(initializer_range),
                                                 name='summary')

        self.activation = None
        if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
            self.activation = tf.keras.activations.tanh

        self.first_dropout = None
        if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.last_dropout = None
        if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, training=False):
        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token
        """
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
            cls_index = None
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get('hidden_states')
            cls_index = inputs.get('cls_index', None)

        if self.summary_type == 'last':
            output = hidden_states[:, -1]
        elif self.summary_type == 'first':
            output = hidden_states[:, 0]
        elif self.summary_type == 'mean':
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == 'cls_index':
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                cls_index = tf.fill(hidden_shape[:-2], hidden_shape[-2] - 1)  # A tensor of shape [batch] or [batch, num choices] filled with the position of the last token
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                cls_index = cls_index[..., tf.newaxis]
            # else:
            #     cls_index = cls_index[..., tf.newaxis]
            #     cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(output, axis=len(hidden_shape) - 2)  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == 'attn':
            raise NotImplementedError

        if training and self.first_dropout is not None:
            output = self.first_dropout(output)

        if self.summary is not None:
            output = self.summary(output)

        if self.activation is not None:
            output = self.activation(output)

        if training and self.last_dropout is not None:
            output = self.last_dropout(output)

        return output

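# A compact sketch of TFSequenceSummary (hypothetical helper; `SimpleNamespace`
# stands in here for a PretrainedConfig-like object carrying the summary_* fields):
def _example_sequence_summary():
    from types import SimpleNamespace
    config = SimpleNamespace(summary_type='last', summary_use_proj=True,
                             summary_proj_to_labels=False, num_labels=0,
                             hidden_size=8, summary_activation='tanh',
                             summary_first_dropout=0.0, summary_last_dropout=0.0)
    summary = TFSequenceSummary(config)
    hidden_states = tf.random.uniform([2, 5, 8])  # [batch, seq_len, hidden]
    return summary(hidden_states)                 # [2, 8]: last token -> proj -> tanh
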
def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]

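# Quick illustration of why shape_list exists (hypothetical helper): in graph
# mode some dimensions are statically None, and shape_list falls back to the
# dynamic tf.shape() entry only for those, keeping the rest as python ints.
def _example_shape_list():
    x = tf.zeros([2, 3])
    assert shape_list(x) == [2, 3]  # fully static shape: plain ints
    return shape_list(x)
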
def get_initializer(initializer_range=0.02):
    """Creates a `tf.initializers.truncated_normal` with the given range.
    Args:
        initializer_range: float, initializer range for stddev.
    Returns:
        TruncatedNormal initializer with stddev = `initializer_range`.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)
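
# Example of the initializer in isolation (hypothetical helper): Keras
# initializers are callable with a shape, and a truncated normal discards
# samples beyond two standard deviations from the mean.
def _example_get_initializer():
    init = get_initializer(0.02)
    return init(shape=[4, 4])  # values lie within about +/- 0.04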