tensorlayer3/tensorlayer/layers/core/common.py

183 lines
7.2 KiB
Python

#! /usr/bin/python
# -*- coding: utf-8 -*-
import os
import tensorlayer as tl
from tensorlayer.files import utils
from tensorlayer import logging
# Lookup table from activation-name strings to the corresponding ``tl.ops``
# entries. "lrelu" is a short alias of "leaky_relu": both keys map to the
# same ``tl.ops.LeakyReLU`` object.
_act_dict = {
    "relu": tl.ops.ReLU,
    "relu6": tl.ops.ReLU6,
    "leaky_relu": tl.ops.LeakyReLU,
    "lrelu": tl.ops.LeakyReLU,
    "softplus": tl.ops.Softplus,
    "tanh": tl.ops.Tanh,
    "sigmoid": tl.ops.Sigmoid,
    "softmax": tl.ops.Softmax,
}
def str2act(act):
    """Translate an activation name into the matching ``tl.ops`` activation.

    Parameters
    ----------
    act : str
        Either a key of ``_act_dict`` (e.g. "relu", "tanh", "lrelu"), or the
        prefix "lrelu" / "leaky_relu" immediately followed by a float literal
        that is used as the LeakyReLU ``alpha`` (e.g. "lrelu0.2",
        "leaky_relu0.1").

    Returns
    -------
    object
        For a name carrying an alpha suffix, the result of
        ``tl.ops.LeakyReLU(alpha=alpha)``. For any other supported name, the
        value stored in ``_act_dict`` is returned as stored; it is not called
        here.

    Raises
    ------
    ValueError
        If the alpha suffix cannot be parsed as a float, or if ``act`` is not
        a supported name. ``ValueError`` is a subclass of ``Exception``, so
        callers that catch ``Exception`` keep working.
    """
    for prefix in ("lrelu", "leaky_relu"):
        # A bare "lrelu" / "leaky_relu" (no suffix) is resolved through the
        # table below, hence the strict length comparison.
        if len(act) > len(prefix) and act.startswith(prefix):
            suffix = act[len(prefix):]
            # Only the float conversion is guarded, so that a failure while
            # constructing LeakyReLU is not misreported as a parse error.
            try:
                alpha = float(suffix)
            except ValueError as err:
                raise ValueError("{} can not be parsed as a float".format(suffix)) from err
            return tl.ops.LeakyReLU(alpha=alpha)

    if act not in _act_dict:
        raise ValueError("Unsupported act: {}".format(act))
    return _act_dict[act]
def _save_weights(net, file_path, format=None):
    """Save the weights of ``net`` into a file of the given format.

    Use net.load_weights() to restore.

    Parameters
    ----------
    net : object
        Model whose ``all_weights`` attribute holds the weights to save. If
        ``all_weights`` is None or empty, a warning is logged and nothing is
        written.
    file_path : str
        Filename to which the model weights will be saved.
    format : str or None
        Saved file format.
        Value should be None, 'hdf5', 'npz', 'npz_dict' or 'ckpt'. Other format is not supported now.
        1) If this is set to None, the postfix of file_path decides the format. If the postfix is
        not in ['h5', 'hdf5', 'npz', 'ckpt'], hdf5 is selected by default.
        2) 'hdf5' (or 'h5') is recognised but currently raises NotImplementedError.
        3) 'npz' saves the weights through ``utils.save_npz`` (sequentially, into a npz file).
        4) 'npz_dict' saves the weights through ``utils.save_npz_dict`` (with their names, as a
        dict, into a npz file).
        5) 'ckpt' is recognised but currently raises NotImplementedError.
        Default None.

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        If the resolved format is 'hdf5', 'h5' or 'ckpt'. This includes the default chosen for an
        unrecognised postfix when ``format`` is None.
    ValueError
        If ``format`` is given explicitly and is not one of the recognised names.

    Examples
    --------
    Save model weights in npz/npz_dict format

    >>> net = vgg16()
    >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
    >>> metric = tl.metric.Accuracy()
    >>> model = tl.models.Model(network=net, loss_fn=tl.cost.cross_entropy, optimizer=optimizer, metrics=metric)
    >>> model.save_weights('./model.npz')
    >>> model.save_weights('./model.npz', format='npz_dict')
    ...
    >>> model.load_weights('./model.npz')

    """
    if net.all_weights is None or len(net.all_weights) == 0:
        logging.warning("Model contains no weights or layers haven't been built, nothing will be saved")
        return

    if format is None:
        # Infer the format from the text after the last '.', falling back to hdf5.
        postfix = file_path.split('.')[-1]
        format = postfix if postfix in ('h5', 'hdf5', 'npz', 'ckpt') else 'hdf5'

    if format in ('hdf5', 'h5'):
        # TODO: call utils.save_weights_to_hdf5(file_path, net) once hdf5 is supported again.
        raise NotImplementedError("hdf5 load/save is not supported now.")
    elif format == 'npz':
        utils.save_npz(net.all_weights, file_path)
    elif format == 'npz_dict':
        utils.save_npz_dict(net.all_weights, file_path)
    elif format == 'ckpt':
        # TODO: enable this when tf save ckpt is enabled
        raise NotImplementedError("ckpt load/save is not supported now.")
    else:
        # The trailing space keeps the two sentences apart once the adjacent
        # literals are concatenated.
        raise ValueError(
            "Save format must be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. "
            "Other format is not supported now."
        )
def _load_weights(net, file_path, format=None, in_order=True, skip=False):
    """Restore the weights of ``net`` from a file previously written by net.save_weights().

    Parameters
    ----------
    net : object
        Model that receives the loaded weights; it is handed to the ``utils`` loader unchanged.
    file_path : str
        Filename from which the model weights will be loaded. It must exist.
    format : str or None
        One of 'hdf5', 'npz', 'npz_dict' or 'ckpt'; it should match the format used when the file
        was saved. When None, the text after the last '.' of ``file_path`` is used instead.
        Default is None.
    in_order : bool
        Intended for the 'hdf5' format: True loads weights sequentially, False loads them by
        matching names, which is particularly useful when restoring a model in eager(graph) mode
        from a file saved in graph(eager) mode. Since hdf5 loading is currently disabled, this
        argument is not read by this function.
        Default is True.
    skip : bool
        Allow skipping weights whose name is mismatched between the file and the model. It is
        forwarded to ``utils.load_and_assign_npz_dict`` for the 'npz_dict' format and is not used
        for 'npz'. For 'hdf5' (once supported), a True value makes 'in_order' ignored.
        Default is False.

    Returns
    -------
    None

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    NotImplementedError
        If the resolved format is 'hdf5', 'h5' or 'ckpt'.
    ValueError
        If the resolved format is none of the recognised names.

    Examples
    --------
    1) load model from a npz file

    >>> net = vgg16()
    >>> optimizer = tl.optimizers.Adam(learning_rate=0.001)
    >>> metric = tl.metric.Accuracy()
    >>> model = tl.models.Model(network=net, loss_fn=tl.cost.cross_entropy, optimizer=optimizer, metrics=metric)
    >>> model.load_weights('./model.npz')

    2) load model from a npz file, which is saved as npz_dict previously

    >>> model.load_weights('./model.npz', format='npz_dict')

    Notes
    -----
    1) 'in_order' only applies to the 'hdf5' format. When loading a weights file that was saved in
    a different mode, loading by name ('in_order' set to False) is the variant described above as
    useful.
    2) 'skip' applies to 'hdf5' and 'npz_dict'. If 'skip' is True, 'in_order' is ignored.

    """
    if not os.path.exists(file_path):
        raise FileNotFoundError("file {} doesn't exist.".format(file_path))

    # Resolve the effective format: explicit argument first, file postfix otherwise.
    fmt = file_path.rsplit('.', 1)[-1] if format is None else format

    if fmt in ('hdf5', 'h5'):
        # TODO: once hdf5 is supported again, load by name through
        # utils.load_hdf5_to_weights(file_path, net, skip) when skip is True or
        # in_order is False, and through
        # utils.load_hdf5_to_weights_in_order(file_path, net) otherwise.
        raise NotImplementedError("hdf5 load/save is not supported now.")

    if fmt == 'npz':
        utils.load_and_assign_npz(file_path, net)
        return

    if fmt == 'npz_dict':
        utils.load_and_assign_npz_dict(file_path, net, skip)
        return

    if fmt == 'ckpt':
        # TODO: enable this when tf save ckpt is enabled
        raise NotImplementedError("ckpt load/save is not supported now.")

    raise ValueError(
        "File format must be 'hdf5', 'npz', 'npz_dict' or 'ckpt'. "
        "Other format is not supported now."
    )