# atom-predict/msunet/train_pl.py


import argparse
import os
import json
import glob
import shutil
import zipfile
import torch
torch.set_float32_matmul_precision('high')
import numpy as np
from sklearn.utils import shuffle
from sklearn.model_selection import KFold
from torch import nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl
from pre import crop_slide, process_slide
from core.model import *
from core.data import *
from core.metrics import *
from typing import Dict
from predict import predict_and_plot

# constants
DATASETS = '0'        # TensorBoard run name
GPUS = 1              # number of GPU devices for the Trainer
SIGMA = 3             # sigma passed to MyDataset
BS = 32               # batch size
RF = 0.9              # LR reduction factor for ReduceLROnPlateau
NW = 8                # DataLoader worker processes
WD = 1e-5             # weight decay
LR = 3e-4             # learning rate
DIM = 256             # patch dimension passed to MyDataset
SLIDE_DIM = 2048      # slide dimension for MyDatasetSlide (commented-out test pipeline)
EPOCHS = 1000         # maximum training epochs
IN_CHANNELS = 3       # input image channels
SAVE_TOP_K = -1       # -1 keeps every checkpoint
EARLY_STOP = 5        # early-stopping patience (epochs)
EVERY_N_EPOCHS = 1    # checkpoint every n epochs
IMAGE_PATH = '/home/gao/mouclear/cc/data_new/msunet/train_and_test_only_10'  # used only by the commented-out legacy __main__ block
WEIGHTS = torch.FloatTensor([1. / 8, 1. / 4, 1. / 2, 1.])  # per-output weights for MyLoss


# pytorch lightning module
class FCRN(pl.LightningModule):
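    """LightningModule wrapping C_FCRN_Aux with the weighted MyLoss and IoU/Dice logging."""
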
    def __init__(self, in_channels):
        super().__init__()
        self.fcrn = C_FCRN_Aux(in_channels)
        self.loss = MyLoss(WEIGHTS)

    def forward(self, x):
        out = self.fcrn(x)
        return out

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=LR, weight_decay=WD)
        scheduler = ReduceLROnPlateau(optimizer, factor=RF, mode='max', patience=2, min_lr=0, verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_dice'
        }

    def training_step(self, train_batch, batch_idx):
        x, y = train_batch
        pd = self.fcrn(x)
        loss = self.loss(pd, y)
        train_iou = iou(pd, y)
        train_dice = dice(pd, y)
        self.log('train_loss', loss)
        self.log('train_iou', train_iou, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_dice', train_dice, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x, y = val_batch
        pd = self.fcrn(x)
        loss = self.loss(pd, y)
        val_iou = iou(pd, y)
        val_dice = dice(pd, y)
        self.log('val_loss', loss)
        self.log('val_iou', val_iou, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_dice', val_dice, on_epoch=True, prog_bar=True, logger=True)

    def predict_step(self, batch, batch_idx):
        x, lbl = batch
        # run the slide tiles through the network in 4 chunks and keep only the last C_FCRN_Aux output
        x = torch.chunk(x[0], chunks=4, dim=0)
        pred = torch.concat([self.fcrn(item)[-1] for item in x])
        return pred.squeeze(), lbl


def pre_process(image_path: str, save_path: str) -> Dict[str, str]:
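    """Crop the train/ and valid/ slide images under image_path into patches and return the per-split output directories."""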
    # dataset split: crop the train/valid slides into patches
    subpath = ["train", "valid"]
    steps = {
        "train": 64,
        "valid": 256,
    }
    result = {}
    for sub in subpath:
        full_path = os.path.join(image_path, sub)
        print("process path:", full_path)
        if not os.path.exists(full_path):
            print("{} does not exist".format(full_path))
            raise Exception("{} does not exist".format(full_path))
        save_full_path = os.path.join(save_path, sub)
        if not os.path.exists(save_full_path):
            os.makedirs(save_full_path, exist_ok=True)
        for item in glob.glob(os.path.join(full_path, "*.jpg")):
            crop_slide(item, save_full_path, 256, steps[sub])
        result[sub] = save_full_path
    # test_imag_list = np.array(glob.glob(os.path.join(image_path, 'test/*.jpg'))).tolist()
    # save_test_path = os.path.join(save_path, 'test')
    # if not os.path.exists(save_test_path):
    #     os.makedirs(save_test_path, exist_ok=True)
    #
    # for img in test_imag_list:
    #     process_slide(img, save_test_path)
    #
    # result["test"] = save_test_path
    return result


def train(process_path: str, model_save_path: str) -> str:
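    """Train FCRN on the pre-processed patches and return the path of the exported last checkpoint."""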
    train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(process_path, 'train'))).tolist()
    valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(process_path, 'valid'))).tolist()
    print('Train nums: {}, Valid nums: {}.'.format(len(train_img_list), len(valid_img_list)))

    train_dataset = MyDataset(train_img_list, dim=DIM, sigma=SIGMA, data_type='train')
    valid_dataset = MyDataset(valid_img_list, dim=DIM, sigma=SIGMA, data_type='valid')

    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BS,
        num_workers=NW,
        drop_last=True,
        sampler=WeightedRandomSampler(train_dataset.sample_weights, len(train_dataset))
    )
    valid_loader = DataLoader(
        dataset=valid_dataset,
        batch_size=BS,
        shuffle=False,
        num_workers=NW,
    )
    # test_loader = DataLoader(
    #     dataset=test_dataset,
    #     batch_size=1,
    #     shuffle=False,
    #     num_workers=1,
    # )

    model = FCRN(IN_CHANNELS)

    logger = TensorBoardLogger(
        name=DATASETS,
        save_dir='logs',
    )
    checkpoint_callback = ModelCheckpoint(
        every_n_epochs=EVERY_N_EPOCHS,
        save_top_k=SAVE_TOP_K,
        monitor='val_dice',
        mode='max',
        save_last=True,
        filename='{epoch}-{val_loss:.2f}-{val_dice:.2f}'
    )
    earlystop_callback = EarlyStopping(
        monitor="val_dice",
        mode="max",
        min_delta=0.00,
        patience=EARLY_STOP,
    )

    # training
    trainer = pl.Trainer(
        accelerator='gpu',
        devices=GPUS,
        max_epochs=EPOCHS,
        logger=logger,
        callbacks=[checkpoint_callback, earlystop_callback]
    )
    trainer.fit(
        model,
        train_loader,
        valid_loader
    )

    # model_name = os.path.join(model_save_path, 'last.ckpt')
    # trainer.save_checkpoint(model_name)
    model_name = os.path.join(model_save_path, 'last.ckpt')
    last_checkpoint_path = checkpoint_callback.last_model_path
    if last_checkpoint_path:
        # copy the last checkpoint to the desired path
        shutil.copy(last_checkpoint_path, model_name)
    else:
        print("No checkpoint was saved by the checkpoint callback.")
    return model_name
# # inference
# predictions = trainer.predict(
# model=model,
# dataloaders=test_loader,
# ckpt_path='best'
# )
#
# preds = np.concatenate([test_dataset.spliter.recover(item[0])[np.newaxis, :, :] for item in predictions]).tolist()
# labels = torch.squeeze(torch.concat([item[1] for item in predictions])).numpy().tolist()
#
# results = {
# 'img_path': test_img_list,
# 'pred': preds,
# 'label': labels,
# }
#
# results_json = json.dumps(results)
# with open(os.path.join(trainer.log_dir, 'test.json'), 'w+') as f:
# f.write(results_json)


# main
def create_save_path(base_path: str) -> Dict[str, str]:
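    """Create the train_original / train_pre_process / train_model_save_path directories under base_path."""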
    # directory layout to create
    paths = {
        "train_original": os.path.join(base_path, 'train_original'),
        "train_pre_process": os.path.join(base_path, "train_pre_process"),
        "model_save_path": os.path.join(base_path, "train_model_save_path"),
    }
    # create the directories
    for path in paths.values():
        os.makedirs(path, exist_ok=True)
    return paths


def arg_parse() -> argparse.Namespace:
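    """Parse the --dataset and --model_output command-line arguments."""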
    # two arguments: the path to the dataset archives and the path where results are saved
    parser = argparse.ArgumentParser(description='Train the msunet model.')
    parser.add_argument('--dataset', default="", help='path to the directory containing the dataset zip archive(s)')
    parser.add_argument('--model_output', default="", help='path under which preprocessed data and checkpoints are written')
    # parser.add_argument('--data_path', type=str, default='', help='path to dataset')
    # parser.add_argument('--save_path', type=str, default='./train_result', help='path to save result')
    args = parser.parse_args()
    return args

if __name__ == '__main__':
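    # end-to-end pipeline: unzip the dataset archive(s), crop patches, train, and copy the last checkpoint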
    args = arg_parse()

    # create the output directory layout
    save_path = create_save_path(args.model_output)
    print("save path: ", save_path)
    original_path = save_path['train_original']
    train_pre_process_path = save_path['train_pre_process']
    model_save_path = save_path['model_save_path']

    # locate the dataset archive(s) in the dataset directory
    for file in os.listdir(args.dataset):
        if file.endswith('.zip'):
            zip_file_path = os.path.join(args.dataset, file)
            # extract the archive into original_path
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(original_path)

    # image_path = './train_and_test'
    pre_process(original_path, train_pre_process_path)

    print("train start")
    model_path = train(train_pre_process_path, model_save_path)

    # print("test start")
    # test_path = os.path.join(original_path, 'test')
    # predict_and_plot(model_path, test_path, args.save_path, True)
    # print("test end")
# if __name__ == '__main__':
# train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist()
# valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist()
# test_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'test'))).tolist()
#
# print('Train nums: {}, Valid nums: {}, Test nums: {}.'.format(len(train_img_list), len(valid_img_list),
# len(test_img_list)))
#
# train_dataset = MyDataset(train_img_list, dim=DIM, sigma=SIGMA, data_type='train')
# valid_dataset = MyDataset(valid_img_list, dim=DIM, sigma=SIGMA, data_type='valid')
# test_dataset = MyDatasetSlide(test_img_list, dim=SLIDE_DIM)
#
# train_loader = DataLoader(
# dataset=train_dataset,
# batch_size=BS,
# num_workers=NW,
# drop_last=True,
# sampler=WeightedRandomSampler(train_dataset.sample_weights, len(train_dataset))
# )
#
# valid_loader = DataLoader(
# dataset=valid_dataset,
# batch_size=BS,
# shuffle=False,
# num_workers=NW,
# )
#
# test_loader = DataLoader(
# dataset=test_dataset,
# batch_size=1,
# shuffle=False,
# num_workers=1,
# )
#
# model = FCRN(IN_CHANNELS)
#
# logger = TensorBoardLogger(
# name=DATASETS,
# save_dir='logs',
# )
#
# checkpoint_callback = ModelCheckpoint(
# every_n_epochs=EVERY_N_EPOCHS,
# save_top_k=SAVE_TOP_K,
# monitor='val_dice',
# mode='max',
# save_last=True,
# filename='{epoch}-{val_loss:.2f}-{val_dice:.2f}'
# )
#
# earlystop_callback = EarlyStopping(
# monitor="val_dice",
# mode="max",
# min_delta=0.00,
# patience=EARLY_STOP,
# )
#
# # training
# trainer = pl.Trainer(
# accelerator='gpu',
# devices=GPUS,
# max_epochs=EPOCHS,
# logger=logger,
# callbacks=[checkpoint_callback, earlystop_callback]
# )
#
# trainer.fit(
# model,
# train_loader,
# valid_loader
# )
#
# # inference
# predictions = trainer.predict(
# model=model,
# dataloaders=test_loader,
# ckpt_path='best'
# )
#
# preds = np.concatenate([test_dataset.spliter.recover(item[0])[np.newaxis, :, :] for item in predictions]).tolist()
# labels = torch.squeeze(torch.concat([item[1] for item in predictions])).numpy().tolist()
#
# results = {
# 'img_path': test_img_list,
# 'pred': preds,
# 'label': labels,
# }
#
# results_json = json.dumps(results)
# with open(os.path.join(trainer.log_dir, 'test.json'), 'w+') as f:
# f.write(results_json)