atom-predict/msunet/test_pl_unequal.py

314 lines
9.4 KiB
Python
Executable File

import os
import json
import glob
import torch
from typing import List
torch.set_float32_matmul_precision('high')
import numpy as np
from sklearn.utils import shuffle
from sklearn.model_selection import KFold
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import *
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl
from core.model import *
from core.data import *
from core.metrics import *
# constants
DATASETS = '0'  # experiment name given to TensorBoardLogger
GPUS = 1  # number of devices passed to pl.Trainer(devices=...)
SIGMA = 3  # `sigma` for the training MyDataset (commented-out code only)
BS = 32  # training batch size (commented-out code only)
RF = 0.9  # ReduceLROnPlateau reduction factor
NW = 8  # training DataLoader worker count (commented-out code only)
WD = 1e-5  # Adam weight decay
LR = 3e-4  # Adam initial learning rate
DIM = 256  # `dim` for the training MyDataset (commented-out code only)
# Sliding-window geometry handed to MyDatasetSlide_test_uneuqal and spliter.recover.
# A smaller setting used previously: patch_size = 256, roi_size = 128 (SLIDE_DIM = 2048).
patch_size = 512
roi_size = 256
EPOCHS = 1000  # Trainer max_epochs
IN_CHANNELS = 3  # input channels of the FCRN network
SAVE_TOP_K = -1  # ModelCheckpoint save_top_k (-1 keeps every checkpoint)
EARLY_STOP = 5  # EarlyStopping patience
EVERY_N_EPOCHS = 1  # ModelCheckpoint frequency, in epochs
IMAGE_PATH = '../../data/linesv/patch_unet/'  # dataset root (commented-out main only)
# IMAGE_PATH2 = '/home/gao/下载/process/test/'
# Loss weights handed to MyLoss; presumably one weight per network output
# (4 values) -- TODO confirm against core.MyLoss.
# torch.tensor replaces the legacy torch.FloatTensor constructor (same float32 result).
WEIGHTS = torch.tensor([1. / 8, 1. / 4, 1. / 2, 1.], dtype=torch.float32)
# pytorch lightning module
class FCRN(pl.LightningModule):
    """LightningModule wrapping the project's ``C_FCRN_Aux`` network.

    Training and validation compute ``MyLoss`` plus IoU / Dice scores and log
    them; prediction runs the patches of one image through the network and
    keeps the last network output. ``C_FCRN_Aux``, ``MyLoss``, ``iou`` and
    ``dice`` presumably come from the ``core`` package star-imports.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.fcrn = C_FCRN_Aux(in_channels)
        # WEIGHTS is the module-level loss-weight tensor.
        self.loss = MyLoss(WEIGHTS)

    def forward(self, x):
        """Return the wrapped network's raw output for ``x``."""
        return self.fcrn(x)

    def configure_optimizers(self):
        """Adam, with LR reduced on plateau of the logged ``val_dice`` (higher is better)."""
        adam = torch.optim.Adam(self.parameters(), lr=LR, weight_decay=WD)
        plateau = ReduceLROnPlateau(
            adam,
            factor=RF,
            mode='max',
            patience=2,
            min_lr=0,
            verbose=True,
        )
        return dict(optimizer=adam, lr_scheduler=plateau, monitor='val_dice')

    def _loss_and_scores(self, batch):
        """Forward ``batch`` and return ``(loss, iou, dice)`` against its target."""
        inputs, target = batch
        output = self.fcrn(inputs)
        # Tuple elements are evaluated left to right: loss, then IoU, then Dice.
        return self.loss(output, target), iou(output, target), dice(output, target)

    def training_step(self, train_batch, batch_idx):
        """Compute and log training loss / IoU / Dice; return the loss for backprop."""
        loss, score_iou, score_dice = self._loss_and_scores(train_batch)
        self.log('train_loss', loss)
        self.log('train_iou', score_iou, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_dice', score_dice, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log validation loss / IoU / Dice (``val_dice`` drives the scheduler)."""
        loss, score_iou, score_dice = self._loss_and_scores(val_batch)
        self.log('val_loss', loss)
        self.log('val_iou', score_iou, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_dice', score_dice, on_epoch=True, prog_bar=True, logger=True)

    def predict_step(self, batch, batch_idx):
        """Predict every patch of the first image in ``batch``.

        Returns ``(prediction, label)``, where ``label`` is passed through untouched.
        """
        images, lbl = batch
        # Feed the patches in 4 slices along dim 0 (presumably to bound peak
        # memory) and keep only the last network output of each slice.
        slices = torch.chunk(images[0], chunks=4, dim=0)
        outputs = []
        for part in slices:
            outputs.append(self.fcrn(part)[-1])
        # NOTE(review): squeeze() drops every size-1 dim, including the patch
        # dim when an image yields a single patch -- confirm recover() copes.
        return torch.cat(outputs).squeeze(), lbl
def predict(model_path: str, test_img_list: List[str], save_path: str, batch: int = 1) -> str:
    """Run sliding-window inference on ``test_img_list`` and dump the results to JSON.

    Each image is cut into ``patch_size`` patches (``roi_size`` ROI) by
    ``MyDatasetSlide_test_uneuqal``. The ``FCRN`` restored from ``model_path``
    predicts every patch, and ``test_dataset.spliter.recover`` stitches the
    patches back into one map per image.

    Args:
        model_path: Lightning checkpoint to restore the model weights from.
        test_img_list: image paths to predict on; output order follows this list.
        save_path: directory for the JSON file (created if it does not exist).
        batch: only used to name the output file ``test_{batch}.json``.

    Returns:
        Path of the written JSON file. It holds ``img_path``, ``pred`` and
        ``label`` lists with one entry (a 2-D nested list for pred/label) per image.
    """
    print('Test nums: ', len(test_img_list))

    test_dataset = MyDatasetSlide_test_uneuqal(test_img_list, ps=patch_size, roi=roi_size)
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=1,  # images may differ in size, so they are never batched together
        shuffle=False,  # keep predictions aligned with test_img_list
        num_workers=1,
    )

    model = FCRN(IN_CHANNELS)
    # Inference only: no TensorBoard logger, checkpointing or early stopping is
    # needed, so none are attached (avoids creating training artefacts on disk).
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=GPUS,
        logger=False,
        enable_checkpointing=False,
    )
    predictions = trainer.predict(
        model=model,
        dataloaders=test_loader,
        ckpt_path=model_path,
    )

    preds = []
    labels = []
    for patches, label in predictions or []:
        # label is assumed to be (1, H, W) because batch_size=1; its H and W give
        # the size of the image the patches are stitched back into.
        height, width = label.shape[1], label.shape[2]
        stitched = test_dataset.spliter.recover(patches, height, width, ps=patch_size, roi=roi_size)
        # Convert per image: images can have different H x W, so they cannot be
        # concatenated into one array/tensor (np.concatenate / torch.cat would fail).
        preds.append(np.asarray(stitched).tolist())
        labels.append(label[0].cpu().numpy().tolist())

    results = {
        'img_path': test_img_list,
        'pred': preds,
        'label': labels,
    }
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    filename = os.path.join(save_path, f'test_{batch}.json')
    with open(filename, 'w') as f:
        json.dump(results, f)
    return filename
# main
#
# if __name__ == '__main__':
# # train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist()
# # valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist()
# test_img_list = np.array(glob.glob('./result/pre_process/img/*.png')).tolist()
# # test_img_list = np.array(glob.glob('/home/gao/mouclear/detect_40/40/*.jpg')).tolist()
#
# # print('Train nums: {}, Valid nums: {}, Test nums: {}.'.format(len(train_img_list), len(valid_img_list), len(test_img_list)))
# print('Test nums: ', len(test_img_list))
#
# # train_dataset = MyDataset(train_img_list, dim=DIM, sigma=SIGMA, data_type='train')
# # valid_dataset = MyDataset(valid_img_list, dim=DIM, sigma=SIGMA, data_type='valid')
# test_dataset = MyDatasetSlide_test_uneuqal(test_img_list, ps=patch_size, roi=roi_size)
#
# # train_loader = DataLoader(
# # dataset = train_dataset,
# # batch_size = BS,
# # num_workers = NW,
# # drop_last = True,
# # sampler = WeightedRandomSampler(train_dataset.sample_weights, len(train_dataset))
# # )
# #
# # valid_loader = DataLoader(
# # dataset = valid_dataset,
# # batch_size = BS,
# # shuffle = False,
# # num_workers = NW,
# # )
#
# test_loader = DataLoader(
# dataset=test_dataset,
# batch_size=1,
# shuffle=False,
# num_workers=1,
# )
#
# model = FCRN(IN_CHANNELS)
#
# # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # model = model.to(device)
# # checkpoint = torch.load('/home/gao/mouclear/cc/code/msunet/logs/0/version_0/checkpoints/last.ckpt', map_location=device)
# #
# # state_dict = checkpoint['state_dict']
#
# # Make a copy of the state dict
# # new_state_dict = state_dict.copy()
#
# # Rename the keys
# # for key in new_state_dict:
# # if key.startswith('model.model'):
# # # pop() removes the old key-value pair; the value is then re-added under the new key
# # new_key = key.replace('model.model', 'model')
# # state_dict[new_key] = state_dict.pop(key)
#
# # state_dict (not the copy) now holds the renamed keys and can be loaded into the model
# # model.load_state_dict(state_dict) # pass strict=False here to ignore mismatched keys
#
# logger = TensorBoardLogger(
# name=DATASETS,
# save_dir='logs',
# )
#
# checkpoint_callback = ModelCheckpoint(
# every_n_epochs=EVERY_N_EPOCHS,
# save_top_k=SAVE_TOP_K,
# monitor='val_dice',
# mode='max',
# save_last=True,
# filename='{epoch}-{val_loss:.2f}-{val_dice:.2f}'
# )
#
# earlystop_callback = EarlyStopping(
# monitor="val_dice",
# mode="max",
# min_delta=0.00,
# patience=EARLY_STOP,
# )
#
# # training
# trainer = pl.Trainer(
# accelerator='gpu',
# devices=GPUS,
# max_epochs=EPOCHS,
# logger=logger,
# callbacks=[checkpoint_callback, earlystop_callback],
# )
#
# # trainer.fit(
# # model,
# # train_loader,
# # valid_loader
# # )
#
# # inference
# predictions = trainer.predict(
# model=model,
# dataloaders=test_loader,
# ckpt_path='./result/train/last.ckpt'
# )
#
# # pre = = spliter.recover(patches, item[1].shape[1], item[1].shape[2], ps=patch_size,roi=roi_size)
#
# preds = np.concatenate([test_dataset.spliter.recover(item[0], item[1].shape[1], item[1].shape[2], ps=patch_size,
# roi=roi_size)[np.newaxis, :, :] for item in
# predictions]).tolist()
# labels = torch.squeeze(torch.concat([item[1] for item in predictions])).numpy().tolist()
#
# results = {
# 'img_path': test_img_list,
# 'pred': preds,
# 'label': labels,
# }
#
# predict_dir= './result/predict/'
# results_json = json.dumps(results)
# with open(os.path.join(predict_dir, 'test.json'), 'w+') as f:
# f.write(results_json)