134 lines
3.7 KiB
Python
134 lines
3.7 KiB
Python
import os
|
|
import json
|
|
import glob
|
|
import torch
|
|
torch.set_float32_matmul_precision('high')
|
|
|
|
import numpy as np
|
|
from sklearn.utils import shuffle
|
|
from sklearn.model_selection import KFold
|
|
from torch import nn
|
|
from torch.utils.data import DataLoader
|
|
from torch.optim.lr_scheduler import ReduceLROnPlateau
|
|
from pytorch_lightning.loggers import TensorBoardLogger
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
|
import pytorch_lightning as pl
|
|
|
|
from core.model import *
|
|
from core.data import *
|
|
from core.metrics import *
|
|
|
|
# constants

# --- data / runtime (used by the inference script below) ---
DATASETS = 'TEM'  # not referenced in this script
IMAGE_PATH = '/home/andrewtal/Workspace/metrials/v15_Final/data/infer/patch'
IN_CHANNELS = 3   # passed to FCRN / C_FCRN_Aux
DIM = 256         # passed to MyDataset as `dim`
SIGMA = 3         # passed to MyDataset as `sigma`
BS = 16           # DataLoader batch_size
NW = 4            # DataLoader num_workers
GPUS = 1          # Trainer devices
TTA = 8           # number of prediction rounds / JSON files written

# --- optimisation ---
LR = 3.0e-4       # Adam learning rate
WD = 1.0e-5       # Adam weight decay
RF = 0.9          # ReduceLROnPlateau factor
EPOCHS = 1_000    # Trainer max_epochs
EARLY_STOP = 20   # not referenced in this script (presumably training-time)

# --- checkpointing / logging (not referenced in this script) ---
SAVE_TOP_K = 1
EVERY_N_EPOCHS = 1
LOG_EVERY_N_STEPS = 1

# Weights handed to MyLoss: 4**-3, 4**-2, 4**-1, 4**0 as a float32 CPU tensor.
# NOTE(review): presumably one weight per target d1..d4 -- confirm in MyLoss.
WEIGHTS = torch.tensor([4.0 ** -p for p in (3, 2, 1, 0)], dtype=torch.float32)


# pytorch lightning module


class FCRN(pl.LightningModule):
    """LightningModule wrapping the project's ``C_FCRN_Aux`` network.

    Batches are 5-tuples ``(x, d1, d2, d3, d4)``: one input followed by four
    targets. Training and validation pass the full network output together
    with ``[d1, d2, d3, d4]`` to both ``MyLoss`` and ``mce``. Prediction keeps
    only the last element of the network output and pairs it with ``d4``.
    """

    def __init__(self, in_channels):
        super().__init__()
        # These attribute names double as state_dict key prefixes for any
        # registered submodules, so checkpoints loaded via ckpt_path rely on them.
        self.fcrn = C_FCRN_Aux(in_channels)
        self.loss = MyLoss(WEIGHTS)

    def forward(self, x):
        """Return the raw output of the wrapped network for input ``x``."""
        return self.fcrn(x)

    def configure_optimizers(self):
        """Build Adam (``LR``/``WD``) with a ReduceLROnPlateau schedule.

        The scheduler multiplies the learning rate by ``RF`` once the monitored
        ``val_mce`` has not decreased for ``patience=4`` scheduler steps.
        """
        adam = torch.optim.Adam(self.parameters(), lr=LR, weight_decay=WD)
        # patience=10 was noted as an alternative setting.
        plateau = ReduceLROnPlateau(
            adam, factor=RF, mode='min', patience=4, min_lr=0, verbose=True,
        )
        return dict(optimizer=adam, lr_scheduler=plateau, monitor='val_mce')

    def _score_batch(self, batch):
        """Run one ``(x, d1, d2, d3, d4)`` batch; return ``(loss, mce)``."""
        inputs, d1, d2, d3, d4 = batch
        targets = [d1, d2, d3, d4]
        outputs = self.fcrn(inputs)
        batch_loss = self.loss(outputs, targets)
        return batch_loss, mce(outputs, targets)

    def training_step(self, train_batch, batch_idx):
        """Log ``train_loss``/``train_mce`` and return the loss to optimise."""
        loss, batch_mce = self._score_batch(train_batch)
        self.log('train_loss', loss)
        self.log('train_mce', batch_mce, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log ``val_loss``/``val_mce``; ``val_mce`` is the scheduler's monitor."""
        loss, batch_mce = self._score_batch(val_batch)
        self.log('val_loss', loss)
        self.log('val_mce', batch_mce, on_epoch=True, prog_bar=True, logger=True)

    def predict_step(self, batch, batch_idx):
        """Return ``(last network output, d4)`` for one batch."""
        inputs, _d1, _d2, _d3, target = batch
        return self.fcrn(inputs)[-1], target


# main


def list_images(image_dir, pattern='*.jpg'):
    """Return the paths in *image_dir* matching *pattern*, sorted.

    ``glob.glob`` yields paths in arbitrary filesystem order. Sorting makes the
    row order of the written JSON reproducible across machines and runs.
    """
    return sorted(glob.glob(os.path.join(image_dir, pattern)))


def collect_predictions(predictions):
    """Flatten per-batch ``(pred, label)`` pairs into per-sample Python lists.

    Args:
        predictions: the list returned by ``trainer.predict``. It holds one
            ``(pred, label)`` tuple per batch (as built by
            ``FCRN.predict_step``), with the batch on dim 0 of each tensor.

    Returns:
        ``(preds, labels)``: two lists with exactly one entry per sample, each
        entry being that sample's tensor with size-1 dims squeezed, as nested
        Python lists or a scalar.

    Each sample is squeezed on its own rather than squeezing the whole
    concatenated tensor. A global squeeze also drops the batch dimension when
    only one image is inferred, which would misalign the output with the
    image list.
    """
    preds = torch.cat([pred for pred, _ in predictions])
    labels = torch.cat([label for _, label in predictions])
    return (
        [torch.squeeze(sample).tolist() for sample in preds],
        [torch.squeeze(sample).tolist() for sample in labels],
    )


if __name__ == '__main__':

    infer_img_list = list_images(IMAGE_PATH)
    if not infer_img_list:
        # Fail early: an empty loader would otherwise crash later in torch.cat([]).
        raise FileNotFoundError('no *.jpg images found in {}'.format(IMAGE_PATH))

    infer_dataset = MyDataset(infer_img_list, dim=DIM, sigma=SIGMA, data_type='test')

    # shuffle must stay False: output rows are matched to infer_img_list by
    # position. This also assumes a single device (GPUS=1); with multiple
    # devices each rank would only see a shard of the data.
    infer_loader = DataLoader(
        dataset=infer_dataset,
        batch_size=BS,
        shuffle=False,
        num_workers=NW,
    )

    model = FCRN(IN_CHANNELS)

    # Trainer used for inference only (no fitting happens in this script).
    # NOTE(review): with logger=False, trainer.log_dir is expected to fall back
    # to the trainer's default_root_dir (the cwd); verify for the installed
    # Lightning version.
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=GPUS,
        max_epochs=EPOCHS,
        logger=False,
    )

    ckpt_path = '/home/andrewtal/Workspace/metrials/v15_Final/code/fcrn/lightning_logs/TEM/version_0/checkpoints/best.ckpt'

    # Test-time augmentation: one prediction round and one JSON file per
    # iteration. NOTE(review): rounds only differ if MyDataset applies random
    # augmentation for data_type='test'; confirm, otherwise all files match.
    for i in range(TTA):
        predictions = trainer.predict(
            model=model,
            dataloaders=infer_loader,
            ckpt_path=ckpt_path,
        )

        preds, labels = collect_predictions(predictions)

        results = {
            'img_path': infer_img_list,
            'pred': preds,
            'label': labels,
        }

        out_path = os.path.join(trainer.log_dir, 'infer_patch_{}.json'.format(i))
        with open(out_path, 'w', encoding='utf-8') as f:
            json.dump(results, f)