diff --git a/msunet/main.py b/msunet/main.py index 28296a8..a310c62 100644 --- a/msunet/main.py +++ b/msunet/main.py @@ -254,7 +254,7 @@ def generate_report(save_path: Dict[str, str], output_dir: str) -> None: post_process_img_section = ReportSection(title="预测结果", ncols=3, elements=img_elements) - report = Report(name="预测结果",title="原子定位结果", sections=[ post_process_img_section, ori_img_section, pre_process_img_section]) + report = Report(title="原子定位结果", sections=[post_process_img_section, ori_img_section, pre_process_img_section]) report.save(output_dir) diff --git a/msunet/pre.py b/msunet/pre.py index 4b4b65f..dccece3 100644 --- a/msunet/pre.py +++ b/msunet/pre.py @@ -49,10 +49,19 @@ def crop_slide(img_path, save_path, patch_size=256, step=128): mask = np.zeros_like(img) mask[points[:, 0], points[:, 1]] = labels + img_item_path = os.path.join(save_path, 'img') + if not os.path.exists(img_item_path): + os.makedirs(img_item_path, exist_ok=True) + + lbl_item_path = os.path.join(save_path, 'lbl') + if not os.path.exists(lbl_item_path): + os.makedirs(lbl_item_path, exist_ok=True) + for i in range(0, h - patch_size + 1, step): for j in range(0, w - patch_size + 1, step): v_nums = np.sum(mask[i:i + patch_size, j:j + patch_size] > 1) + Image.fromarray(img[i:i + patch_size, j:j + patch_size]).save( os.path.join(save_path, 'img', '{}_{}_{}_{}.png'.format(base_name, str(i), str(j), str(v_nums))) ) @@ -84,6 +93,15 @@ def process_slide(img_path, save_path): mask = np.zeros_like(img) mask[points[:, 0], points[:, 1]] = labels + img_item_path = os.path.join(save_path, 'img') + if not os.path.exists(img_item_path): + os.makedirs(img_item_path, exist_ok=True) + + + lbl_item_path = os.path.join(save_path, 'lbl') + if not os.path.exists(lbl_item_path): + os.makedirs(lbl_item_path, exist_ok=True) + Image.fromarray(img).save( os.path.join(save_path, 'img', '{}.png'.format(base_name)) ) diff --git a/msunet/test_pl_unequal.py b/msunet/test_pl_unequal.py index eaa1a21..55a3722 100644 --- a/msunet/test_pl_unequal.py +++ b/msunet/test_pl_unequal.py @@ -194,7 +194,7 @@ def predict(model_path:str, test_img_list:List[str], save_path:str)->str: return filename # main - +# # if __name__ == '__main__': # # train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist() # # valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist() diff --git a/msunet/train_pl.py b/msunet/train_pl.py index 74414c3..0fb1f10 100644 --- a/msunet/train_pl.py +++ b/msunet/train_pl.py @@ -16,11 +16,11 @@ from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks.early_stopping import EarlyStopping import pytorch_lightning as pl - +from pre import crop_slide, process_slide from core.model import * from core.data import * from core.metrics import * - +from typing import Dict # constants DATASETS = '0' @@ -93,12 +93,49 @@ class FCRN(pl.LightningModule): return pred.squeeze(), lbl -# main +def pre_process(image_path: str, save_path: str)->Dict[str, str]: + # 数据集划分 + subpath = ["train", "valid"] + steps = { + "train": 64, + "valid": 256, -if __name__ == '__main__': - train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist() - valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist() - test_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'test'))).tolist() + } + result = {} + for sub in subpath: + full_path = os.path.join(image_path, sub) + print("process path:", full_path) + if not os.path.exists(full_path): + print("f{} not exist".format(full_path)) + raise Exception("f{} not exist".format(full_path)) + + save_full_path = os.path.join(save_path, sub) + if not os.path.exists(save_full_path): + os.makedirs(save_full_path, exist_ok=True) + + for item in glob.glob(os.path.join(full_path, "*.jpg")): + crop_slide(item, save_full_path, 256, steps[sub]) + result[sub] = save_full_path + + test_imag_list = np.array(glob.glob(os.path.join(image_path, 'test/*.jpg'))).tolist() + save_test_path = os.path.join(save_path, 'test') + if not os.path.exists(save_test_path): + os.makedirs(save_test_path, exist_ok=True) + + for img in test_imag_list: + process_slide(img, save_test_path) + + result["test"] = save_test_path + return result + +def train(image_path: str, save_path: str)->str: + if not os.path.exists(save_path): + os.makedirs(save_path, exist_ok=True) + # 数据集划分 + pre_result = pre_process(image_path, save_path) + train_img_list = np.array(glob.glob('{}/img/*.png'.format(pre_result['train']))).tolist() + valid_img_list = np.array(glob.glob('{}/img/*.png'.format(pre_result['valid']))).tolist() + test_img_list = np.array(glob.glob('{}/img/*.png'.format(pre_result['test']))).tolist() print('Train nums: {}, Valid nums: {}, Test nums: {}.'.format(len(train_img_list), len(valid_img_list), len(test_img_list))) @@ -185,4 +222,103 @@ if __name__ == '__main__': results_json = json.dumps(results) with open(os.path.join(trainer.log_dir, 'test.json'), 'w+') as f: - f.write(results_json) \ No newline at end of file + f.write(results_json) + +# main + +if __name__ == '__main__': + image_path = './train_and_test' + save_path = './train_result' + + train(image_path, save_path) +# if __name__ == '__main__': +# train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist() +# valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist() +# test_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'test'))).tolist() +# +# print('Train nums: {}, Valid nums: {}, Test nums: {}.'.format(len(train_img_list), len(valid_img_list), +# len(test_img_list))) +# +# train_dataset = MyDataset(train_img_list, dim=DIM, sigma=SIGMA, data_type='train') +# valid_dataset = MyDataset(valid_img_list, dim=DIM, sigma=SIGMA, data_type='valid') +# test_dataset = MyDatasetSlide(test_img_list, dim=SLIDE_DIM) +# +# train_loader = DataLoader( +# dataset=train_dataset, +# batch_size=BS, +# num_workers=NW, +# drop_last=True, +# sampler=WeightedRandomSampler(train_dataset.sample_weights, len(train_dataset)) +# ) +# +# valid_loader = DataLoader( +# dataset=valid_dataset, +# batch_size=BS, +# shuffle=False, +# num_workers=NW, +# ) +# +# test_loader = DataLoader( +# dataset=test_dataset, +# batch_size=1, +# shuffle=False, +# num_workers=1, +# ) +# +# model = FCRN(IN_CHANNELS) +# +# logger = TensorBoardLogger( +# name=DATASETS, +# save_dir='logs', +# ) +# +# checkpoint_callback = ModelCheckpoint( +# every_n_epochs=EVERY_N_EPOCHS, +# save_top_k=SAVE_TOP_K, +# monitor='val_dice', +# mode='max', +# save_last=True, +# filename='{epoch}-{val_loss:.2f}-{val_dice:.2f}' +# ) +# +# earlystop_callback = EarlyStopping( +# monitor="val_dice", +# mode="max", +# min_delta=0.00, +# patience=EARLY_STOP, +# ) +# +# # training +# trainer = pl.Trainer( +# accelerator='gpu', +# devices=GPUS, +# max_epochs=EPOCHS, +# logger=logger, +# callbacks=[checkpoint_callback, earlystop_callback] +# ) +# +# trainer.fit( +# model, +# train_loader, +# valid_loader +# ) +# +# # inference +# predictions = trainer.predict( +# model=model, +# dataloaders=test_loader, +# ckpt_path='best' +# ) +# +# preds = np.concatenate([test_dataset.spliter.recover(item[0])[np.newaxis, :, :] for item in predictions]).tolist() +# labels = torch.squeeze(torch.concat([item[1] for item in predictions])).numpy().tolist() +# +# results = { +# 'img_path': test_img_list, +# 'pred': preds, +# 'label': labels, +# } +# +# results_json = json.dumps(results) +# with open(os.path.join(trainer.log_dir, 'test.json'), 'w+') as f: +# f.write(results_json) \ No newline at end of file