# tensorlayer3/densnet121-imagenet.py
#
# 23 lines
# 883 B
# Python

import tensorlayer as tl
from Densnet import *
import numpy as np
from tensorlayer.dataflow import Dataset, Dataloader
import os
from PIL import Image
import tensorflow as tf
import tensorflow_datasets as tfds
# Train DenseNet-121 on the ImageNet-2012 training split.
#
# ImageNet must be downloaded manually before tfds can prepare it; follow the
# instructions printed by tfds. Reference:
# https://www.tensorflow.org/datasets/catalog/imagenet2012
DS = tfds.load('imagenet2012', split='train', shuffle_files=True, as_supervised=True)
# NOTE(review): `.repeat()` with no count makes the dataset infinite; if
# `Model.train` iterates the whole dataset once per epoch, the first epoch
# never finishes -- confirm, and drop `.repeat()` if so.
# NOTE(review): the images are batched without any resize/normalisation step;
# presumably ImageNet samples vary in size, which `batch(32)` cannot stack --
# confirm what input shape `densnet121` expects and add a `map` if needed.
DS = DS.repeat().shuffle(1024).batch(32)
DS = DS.prefetch(tf.data.experimental.AUTOTUNE)

# Build the raw network first; `Model` is bound to the training wrapper below.
network = densnet121()
# NOTE(review): `n_epoch` here (50) disagrees with the literal 500 passed to
# `train` below; the literal is kept as-is -- confirm the intended value.
n_epoch = 50
print_freq = 2
train_weights = network.trainable_weights
optimizer = tl.optimizers.Momentum(0.05, 0.9)
Model = tl.models.Model(
    network=network,
    loss_fn=tl.cost.softmax_cross_entropy_with_logits,
    optimizer=optimizer,
)
# The dataset variable is `DS`; the previous `ds` was undefined (NameError).
Model.train(n_epoch=500, train_dataset=DS, print_freq=print_freq)