forked from TensorLayer/tensorlayer3
ADD file via upload
This commit is contained in:
parent
a1fc07da63
commit
f7375b2c02
|
@ -0,0 +1,31 @@
|
|||
import tensorlayer as tl
|
||||
from Densnet import *
|
||||
import numpy as np
|
||||
from tensorlayer.dataflow import Dataset
|
||||
|
||||
X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3))
|
||||
def generator_train():
|
||||
inputs = X_train
|
||||
targets = y_train
|
||||
if len(inputs) != len(targets):
|
||||
raise AssertionError("The length of inputs and targets should be equal")
|
||||
for _input, _target in zip(inputs, targets):
|
||||
yield (_input, np.array(_target))
|
||||
|
||||
model = dnet100()
|
||||
n_epoch = 50
|
||||
batch_size = 128
|
||||
print_freq = 2
|
||||
shuffle_buffer_size = 128
|
||||
train_weights = model.trainable_weights
|
||||
optimizer = tl.optimizers.Momentum(0.05, 0.9)
|
||||
train_ds = tl.dataflow.FromGenerator(
|
||||
generator_train, output_types=(tl.float32, tl.int32) , column_names=['data', 'label']
|
||||
)
|
||||
train_ds = tl.dataflow.Shuffle(train_ds,shuffle_buffer_size)
|
||||
train_ds = tl.dataflow.Batch(train_ds,batch_size)
|
||||
|
||||
|
||||
optimizer = tl.optimizers.Momentum(0.05, 0.9)
|
||||
model = tl.models.Model(network=model, loss_fn=tl.cost.softmax_cross_entropy_with_logits, optimizer=optimizer)
|
||||
model.train(n_epoch=500, train_dataset=train_ds, print_freq=2)
|
Loading…
Reference in New Issue