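# tensorlayer3/densnet100-cifar10.py
#
# Train the `densnet100` model (from the local `Densnet` module) on the
# CIFAR-10 training split with TensorLayer.
# (Source listing metadata: 31 lines, 1.1 KiB, Python.)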

import tensorlayer as tl
from Densnet import *
import numpy as np
from tensorlayer.dataflow import Dataset
# Load CIFAR-10 with the images reshaped to (-1, 32, 32, 3).  Only the training
# split (X_trn, y_trn) is consumed by the code below; the test split
# (X_te, y_te) is loaded but not used here.
(X_trn, y_trn,
 X_te, y_te) = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3))
def generator_train():
    """Yield ``(image, label)`` pairs from the module-level training arrays.

    Each label is wrapped in ``np.array`` before being yielded. Because this
    is a generator, the length check below runs on the first ``next()`` call,
    not when the function is called.

    Raises:
        AssertionError: if ``X_trn`` and ``y_trn`` have different lengths.
    """
    images, labels = X_trn, y_trn
    lengths_match = len(images) == len(labels)
    if not lengths_match:
        raise AssertionError("The length of inputs and targets should be equal")
    yield from ((image, np.array(label)) for image, label in zip(images, labels))
# Training hyper-parameters. These values are the single source of truth for
# the run below; previously ``n_epoch`` and ``print_freq`` were declared but
# the train() call used literals instead.
n_epoch = 50  # NOTE(review): train() used to hard-code 500; confirm 50 is intended.
batch_size = 128
print_freq = 2
shuffle_buffer_size = 128

# Build the raw network (``densnet100`` comes from the star-imported local
# ``Densnet`` module). Kept under its own name so ``Model`` is bound only once,
# to the training wrapper below.
net = densnet100()

# Input pipeline: stream (image, label) pairs from generator_train, shuffle
# with a buffer of ``shuffle_buffer_size``, then group into batches.
train_ds = tl.dataflow.FromGenerator(
    generator_train, output_types=(tl.float32, tl.int32), column_names=['data', 'label']
)
train_ds = tl.dataflow.Shuffle(train_ds, shuffle_buffer_size)
train_ds = tl.dataflow.Batch(train_ds, batch_size)

# A single optimizer instance. The positional arguments are presumably
# (learning_rate=0.05, momentum=0.9); confirm against tl.optimizers.Momentum.
optimizer = tl.optimizers.Momentum(0.05, 0.9)

# Wrap the network with its loss and optimizer, then train using the
# configuration declared above (not hard-coded literals).
Model = tl.models.Model(
    network=net, loss_fn=tl.cost.softmax_cross_entropy_with_logits, optimizer=optimizer
)
Model.train(n_epoch=n_epoch, train_dataset=train_ds, print_freq=print_freq)