From f7375b2c02641122dc772e5207e1691092b8a9de Mon Sep 17 00:00:00 2001 From: lujinqing <1738446749@qq.com> Date: Sat, 30 Oct 2021 11:53:36 +0800 Subject: [PATCH] ADD file via upload --- densnet100-cifar10.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 densnet100-cifar10.py diff --git a/densnet100-cifar10.py b/densnet100-cifar10.py new file mode 100644 index 0000000..ac9e48b --- /dev/null +++ b/densnet100-cifar10.py @@ -0,0 +1,31 @@ +import tensorlayer as tl +from Densnet import * +import numpy as np +from tensorlayer.dataflow import Dataset + +X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3)) +def generator_train(): + inputs = X_train + targets = y_train + if len(inputs) != len(targets): + raise AssertionError("The length of inputs and targets should be equal") + for _input, _target in zip(inputs, targets): + yield (_input, np.array(_target)) + +model = dnet100() +n_epoch = 50 +batch_size = 128 +print_freq = 2 +shuffle_buffer_size = 128 +train_weights = model.trainable_weights +optimizer = tl.optimizers.Momentum(0.05, 0.9) +train_ds = tl.dataflow.FromGenerator( + generator_train, output_types=(tl.float32, tl.int32) , column_names=['data', 'label'] +) +train_ds = tl.dataflow.Shuffle(train_ds,shuffle_buffer_size) +train_ds = tl.dataflow.Batch(train_ds,batch_size) + + +optimizer = tl.optimizers.Momentum(0.05, 0.9) +model = tl.models.Model(network=model, loss_fn=tl.cost.softmax_cross_entropy_with_logits, optimizer=optimizer) +model.train(n_epoch=500, train_dataset=train_ds, print_freq=2) \ No newline at end of file