# tensorlayer3/densenet_imagenet.py
#
# NOTE: this copy was captured from a web code viewer. The viewer flagged the
# file as containing ambiguous Unicode characters, i.e. characters that might
# be confused with other characters.

import os
os.environ['TL_BACKEND'] = 'tensorflow'
import time
import multiprocessing
import tensorflow as tf
from tensorlayer.models import TrainOneStep
from tensorlayer.layers import Module
import tensorlayer as tl
from .densenet import builddensenet
def load_imagenet_dataset(shape=(-1, 256, 256, 3), plotable=False):
    """Load the ImageNet data from the local environment.

    This is a placeholder: how ImageNet is read depends on where and how it
    is stored locally, so the body has to be written by the user.

    Args:
        shape: Requested layout of the image arrays. Not used by the
            placeholder.
        plotable: Not used by the placeholder.

    Returns:
        A 4-tuple ``(X_train, y_train, X_test, y_test)``: training images,
        training labels, test images and test labels.

    Raises:
        NotImplementedError: Always, until a real loader replaces this
            placeholder. Failing here is clearer than handing the caller a
            value it cannot unpack into four names.
    """
    raise NotImplementedError(
        "load_imagenet_dataset() is a placeholder: implement it to return "
        "(X_train, y_train, X_test, y_test) for your local copy of ImageNet"
    )
# Module-level setup: logging, data, network and training hyper-parameters.
# Everything bound here is read by name by the code further down the file.

# enable debug logging
tl.logging.set_verbosity(tl.logging.DEBUG)
# prepare image_net data
# NOTE(review): load_imagenet_dataset() is only a placeholder in this file; it
# has to be implemented to return the four arrays before this line can succeed.
X_train, y_train, X_test, y_test = load_imagenet_dataset(shape=(-1, 256, 256, 3), plotable=False)
# get the network (builddensenet comes from the sibling .densenet module)
net = builddensenet("densenet-121")
# training settings
# batch_size is passed to .batch() on both the train and the test dataset.
batch_size = 128
# Number of iterations of the outer training loop.
n_epoch = 500
# Passed to the Adam optimizer below.
learning_rate = 0.0001
# NOTE(review): print_freq is not referenced in the visible part of this file.
print_freq = 5
# Whole batches per epoch; int() truncates the fractional remainder.
n_step_epoch = int(len(y_train) / batch_size)
# NOTE(review): n_step is not referenced in the visible part of this file.
n_step = n_epoch * n_step_epoch
# Buffer size handed to train_ds.shuffle().
shuffle_buffer_size = 128
# Weights handed to TrainOneStep together with the optimizer.
train_weights = net.trainable_weights
optimizer = tl.optimizers.Adam(learning_rate)
# Accuracy metric; the training loop updates, reads and resets it per batch.
metrics = tl.metric.Accuracy()
def generator_train():
    """Yield ``(image, label)`` pairs from the training set, one at a time.

    Reads the module-level ``X_train`` and ``y_train``.

    Raises:
        AssertionError: If the two collections differ in length. The check
            runs when the generator is first advanced.
    """
    images, labels = X_train, y_train
    n_images, n_labels = len(images), len(labels)
    if n_images != n_labels:
        raise AssertionError("The length of inputs and targets should be equal")
    yield from zip(images, labels)
def generator_test():
    """Yield ``(image, label)`` pairs from the test set, one at a time.

    Reads the module-level ``X_test`` and ``y_test``.

    Raises:
        AssertionError: If the two collections differ in length. The check
            runs when the generator is first advanced.
    """
    images, labels = X_test, y_test
    n_images, n_labels = len(images), len(labels)
    if n_images != n_labels:
        raise AssertionError("The length of inputs and targets should be equal")
    yield from zip(images, labels)
def _map_fn_train(img, target):
    """Augment and standardize one training sample.

    Args:
        img: Image tensor; it must be large enough for a 224x224x3 crop.
        target: Label tensor; it is reshaped to a scalar.

    Returns:
        ``(image, label)`` with a 224x224x3 image and a scalar label.
    """
    # Geometric augmentation: random 224x224x3 crop, then a random
    # left/right flip.
    augmented = tf.image.random_crop(img, [224, 224, 3])
    augmented = tf.image.random_flip_left_right(augmented)

    # Photometric augmentation: brightness jitter first, contrast second.
    augmented = tf.image.random_brightness(augmented, max_delta=63)
    augmented = tf.image.random_contrast(augmented, lower=0.2, upper=1.8)

    # Per-image standardization (TensorFlow documents this as subtracting the
    # mean and dividing by the adjusted standard deviation of the pixels).
    standardized = tf.image.per_image_standardization(augmented)

    # Collapse the label to a scalar.
    label = tf.reshape(target, ())
    return standardized, label
def _map_fn_test(img, target):
    """Resize and standardize one test sample (no random augmentation).

    Args:
        img: Image tensor.
        target: Label tensor; it is reshaped to a scalar.

    Returns:
        ``(image, label)`` with a 224x224x3 image and a scalar label.
    """
    # Bring the image to 224x224 with tf.image.resize_with_pad (this resizes
    # and pads; it does not crop).
    resized = tf.image.resize_with_pad(img, 224, 224)

    # Per-image standardization, then pin the shape to (224, 224, 3).
    standardized = tf.reshape(
        tf.image.per_image_standardization(resized),
        (224, 224, 3),
    )

    # Collapse the label to a scalar.
    label = tf.reshape(target, ())
    return standardized, label
# dataset API and augmentation

# Training pipeline: generator -> augmentation map -> shuffle -> prefetch ->
# batch. The transformations are applied in exactly this order.
train_ds = (
    tf.data.Dataset.from_generator(
        generator_train,
        output_types=(tf.float32, tf.int32),
    )
    .map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())
    .shuffle(shuffle_buffer_size)
    .prefetch(buffer_size=4096)
    .batch(batch_size)
)

# Test pipeline: generator -> preprocessing map -> prefetch -> batch.
# There is no shuffle step here.
test_ds = (
    tf.data.Dataset.from_generator(
        generator_test,
        output_types=(tf.float32, tf.int32),
    )
    .map(_map_fn_test, num_parallel_calls=multiprocessing.cpu_count())
    .prefetch(buffer_size=4096)
    .batch(batch_size)
)
class WithLoss(Module):
    """Module that pairs a network with a loss function.

    ``forward`` runs the wrapped network on the input and returns the loss of
    that output against the given label.
    """

    def __init__(self, net, loss_fn):
        """Store the wrapped network and the loss callable."""
        super().__init__()
        self._net = net
        self._loss_fn = loss_fn

    def forward(self, data, label):
        """Return ``loss_fn(net(data), label)``."""
        return self._loss_fn(self._net(data), label)
# Wrap the network with its loss, then wrap that with the optimizer and the
# weights to be updated.
net_with_loss = WithLoss(net, loss_fn=tl.cost.softmax_cross_entropy_with_logits)
net_with_train = TrainOneStep(net_with_loss, optimizer, train_weights)

# NOTE(review): the indentation of this loop was lost in the copy it was
# rebuilt from. The progress prints are placed inside the batch loop, so they
# report running averages after every step - confirm against the upstream file.
for epoch in range(n_epoch):
    start_time = time.time()
    net.set_train()
    # Running totals for this epoch.
    train_loss = train_acc = n_iter = 0
    for X_batch, y_batch in train_ds:
        X_batch = tl.ops.convert_to_tensor(X_batch.numpy(), dtype=tl.float32)
        y_batch = tl.ops.convert_to_tensor(y_batch.numpy(), dtype=tl.int64)

        # Call the TrainOneStep wrapper; its return value is accumulated as
        # this batch's loss.
        train_loss += net_with_train(X_batch, y_batch)
        n_iter += 1

        # Second forward pass on the same batch, used only for the accuracy
        # metric; the metric is reset so each result covers one batch.
        metrics.update(net(X_batch), y_batch)
        train_acc += metrics.result()
        metrics.reset()

        print("Epoch {} of {} took {}".format(epoch + 1, n_epoch, time.time() - start_time))
        print(" train loss: {}".format(train_loss / n_iter))
        print(" train acc: {}".format(train_acc / n_iter))