ADD file via upload

This commit is contained in:
lujinqing 2021-10-30 11:53:36 +08:00
parent a1fc07da63
commit f7375b2c02
1 changed files with 31 additions and 0 deletions

31
densnet100-cifar10.py Normal file
View File

@ -0,0 +1,31 @@
import tensorlayer as tl
from Densnet import *
import numpy as np
from tensorlayer.dataflow import Dataset
X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3))
def generator_train():
inputs = X_train
targets = y_train
if len(inputs) != len(targets):
raise AssertionError("The length of inputs and targets should be equal")
for _input, _target in zip(inputs, targets):
yield (_input, np.array(_target))
model = dnet100()
n_epoch = 50
batch_size = 128
print_freq = 2
shuffle_buffer_size = 128
train_weights = model.trainable_weights
optimizer = tl.optimizers.Momentum(0.05, 0.9)
train_ds = tl.dataflow.FromGenerator(
generator_train, output_types=(tl.float32, tl.int32) , column_names=['data', 'label']
)
train_ds = tl.dataflow.Shuffle(train_ds,shuffle_buffer_size)
train_ds = tl.dataflow.Batch(train_ds,batch_size)
optimizer = tl.optimizers.Momentum(0.05, 0.9)
model = tl.models.Model(network=model, loss_fn=tl.cost.softmax_cross_entropy_with_logits, optimizer=optimizer)
model.train(n_epoch=500, train_dataset=train_ds, print_freq=2)