# Forked from TensorLayer/tensorlayer3 (24 lines, 766 B, Python).
import tensorlayer as tl
|
|
|
|
from Densnet import *
|
|
import numpy as npy
|
|
from tensorlayer.dataflow import Dataset, Dataloader
|
|
import os
|
|
from PIL import Image
|
|
import tensorflow as tf
|
|
import tensorflow_datasets as tfds
|
|
|
|
|
|
# --- Input pipeline ---------------------------------------------------------
# Load the 'imagenet2012' training split through tensorflow_datasets.
# as_supervised=True is expected to yield (image, label) pairs, which is the
# form the training loop below consumes.
ds = tfds.load('imagenet2012', split='train', shuffle_files=True, as_supervised=True)

# NOTE(review): no resize/normalisation step is applied before batching. If
# the decoded images do not all share one shape, `.batch(32)` will fail, and
# the input size expected by dnet121 is not visible from this script --
# confirm whether a preprocessing `map` is required here.

# Shuffle within a 1024-element buffer, then form batches of 32.
# There is deliberately no `.repeat()`: with no count it makes the dataset
# infinite, so a loop that walks the dataset once per epoch would never reach
# the end of the first epoch. Repetition is controlled by `n_epoch` instead.
ds = ds.shuffle(1024).batch(32)
# Overlap input preparation with training; buffer size is tuned by tf.data.
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

# --- Hyper-parameters -------------------------------------------------------
# 500 is the epoch count the script actually passed to train(); the named
# value is now the single source of truth for that call.
n_epoch = 500
# Forwarded unchanged to train() as its `print_freq` argument.
print_freq = 2

# --- Model ------------------------------------------------------------------
# dnet121 comes from the local `Densnet` module (star-imported at file top).
net = dnet121()

# Presumably (learning_rate=0.05, momentum=0.9) -- arguments are positional,
# verify against the tl.optimizers.Momentum signature.
optimizer = tl.optimizers.Momentum(0.05, 0.9)

# Wrap the network together with its loss and optimizer for training.
model = tl.models.Model(
    network=net,
    loss_fn=tl.cost.softmax_cross_entropy_with_logits,
    optimizer=optimizer,
)

# --- Training ---------------------------------------------------------------
model.train(n_epoch=n_epoch, train_dataset=ds, print_freq=print_freq)
|
|
|