# tensorlayer3/densnet121-imagenet.py
#
# 24 lines
# 766 B
# Python

import tensorlayer as tl
from Densnet import *
import numpy as npy
from tensorlayer.dataflow import Dataset, Dataloader
import os
from PIL import Image
import tensorflow as tf
import tensorflow_datasets as tfds
# Training configuration. These names are the single source of truth: the
# original script defined n_epoch/print_freq but passed hard-coded literals
# (n_epoch=500) to model.train, silently ignoring n_epoch.
n_epoch = 50
print_freq = 2
batch_size = 32
shuffle_buffer = 1024
# NOTE(review): assumes dnet121() expects fixed-size 224x224 RGB float input
# (the usual DenseNet-121 resolution) -- TODO confirm against Densnet.py.
image_size = (224, 224)


def _preprocess(image, label):
    """Resize one (image, label) example to a fixed shape and scale pixels to [0, 1].

    Images from tfds 'imagenet2012' do not all share one height/width, and
    ``Dataset.batch`` can only stack equally-shaped tensors, so every image is
    resized to ``image_size`` before batching. The label is passed through.
    """
    image = tf.image.resize(image, image_size)  # output is float32
    return image / 255.0, label


ds = tfds.load('imagenet2012', split='train', shuffle_files=True, as_supervised=True)
# No unbounded .repeat(): that made the dataset infinite, so one pass over
# train_dataset inside model.train would never finish and n_epoch was moot.
# Shuffle before map so the shuffle buffer holds the raw images rather than
# the resized float32 copies.
ds = (
    ds.shuffle(shuffle_buffer)
    .map(_preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Keep the raw network and the training wrapper under separate names instead
# of rebinding `model`. The former `train_weights = model.trainable_weights`
# was never read anywhere and has been dropped.
network = dnet121()
optimizer = tl.optimizers.Momentum(0.05, 0.9)
model = tl.models.Model(
    network=network,
    loss_fn=tl.cost.softmax_cross_entropy_with_logits,
    optimizer=optimizer,
)
model.train(n_epoch=n_epoch, train_dataset=ds, print_freq=print_freq)