forked from TensorLayer/tensorlayer3
23 lines
883 B
Python
23 lines
883 B
Python
import tensorlayer as tl
|
|
from Densnet import *
|
|
import numpy as np
|
|
from tensorlayer.dataflow import Dataset, Dataloader
|
|
import os
|
|
from PIL import Image
|
|
import tensorflow as tf
|
|
import tensorflow_datasets as tfds
|
|
|
|
# Train DenseNet-121 on the ImageNet-2012 training split.
#
# ImageNet must be downloaded manually (follow the instructions tfds prints);
# see https://www.tensorflow.org/datasets/catalog/imagenet2012
DS = tfds.load('imagenet2012', split='train', shuffle_files=True, as_supervised=True)

# NOTE(review): repeat() without a count yields an endless dataset. If
# Model.train iterates the dataset once per epoch, an epoch would never
# finish -- confirm, and drop repeat() if so.
# NOTE(review): the TFDS catalog lists imagenet2012 images with variable
# height/width; batch(32) may need a resize map() beforehand -- confirm
# against the input size densnet121 expects.
DS = DS.repeat().shuffle(1024).batch(32)
DS = DS.prefetch(tf.data.experimental.AUTOTUNE)

Model = densnet121()
n_epoch = 50
print_freq = 2
train_weights = Model.trainable_weights
optimizer = tl.optimizers.Momentum(0.05, 0.9)

# Wrap the bare network in the high-level training API; `Model` is rebound
# from the network to the wrapper here.
Model = tl.models.Model(network=Model, loss_fn=tl.cost.softmax_cross_entropy_with_logits, optimizer=optimizer)

# The input pipeline is bound to `DS`; this script defines no lowercase `ds`,
# so passing `ds` would raise NameError.
# NOTE(review): the epoch count passed here (500) differs from `n_epoch` (50)
# defined above; the literal is kept to preserve behaviour -- confirm which
# value is intended.
Model.train(n_epoch=500, train_dataset=DS, print_freq=print_freq)
|