Training integration

somunslotus 2024-07-11 17:20:46 +08:00
parent 17f801f7ae
commit 32c4a249e0
4 changed files with 164 additions and 10 deletions


@@ -254,7 +254,7 @@ def generate_report(save_path: Dict[str, str], output_dir: str) -> None:
     post_process_img_section = ReportSection(title="预测结果", ncols=3, elements=img_elements)
-    report = Report(name="预测结果",title="原子定位结果", sections=[ post_process_img_section, ori_img_section, pre_process_img_section])
+    report = Report(title="原子定位结果", sections=[post_process_img_section, ori_img_section, pre_process_img_section])
     report.save(output_dir)


@@ -49,10 +49,19 @@ def crop_slide(img_path, save_path, patch_size=256, step=128):
    mask = np.zeros_like(img)
    mask[points[:, 0], points[:, 1]] = labels
    img_item_path = os.path.join(save_path, 'img')
    if not os.path.exists(img_item_path):
        os.makedirs(img_item_path, exist_ok=True)
    lbl_item_path = os.path.join(save_path, 'lbl')
    if not os.path.exists(lbl_item_path):
        os.makedirs(lbl_item_path, exist_ok=True)
    for i in range(0, h - patch_size + 1, step):
        for j in range(0, w - patch_size + 1, step):
            v_nums = np.sum(mask[i:i + patch_size, j:j + patch_size] > 1)
            Image.fromarray(img[i:i + patch_size, j:j + patch_size]).save(
                os.path.join(save_path, 'img', '{}_{}_{}_{}.png'.format(base_name, str(i), str(j), str(v_nums)))
            )
@@ -84,6 +93,15 @@ def process_slide(img_path, save_path):
    mask = np.zeros_like(img)
    mask[points[:, 0], points[:, 1]] = labels
    img_item_path = os.path.join(save_path, 'img')
    if not os.path.exists(img_item_path):
        os.makedirs(img_item_path, exist_ok=True)
    lbl_item_path = os.path.join(save_path, 'lbl')
    if not os.path.exists(lbl_item_path):
        os.makedirs(lbl_item_path, exist_ok=True)
    Image.fromarray(img).save(
        os.path.join(save_path, 'img', '{}.png'.format(base_name))
    )
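Both crop_slide and process_slide now create the img and lbl output folders with the same four lines, and the os.path.exists check is redundant once exist_ok=True is passed. A hypothetical shared helper (not part of this commit) could collapse the pattern:

import os

def ensure_output_dirs(save_path: str):
    # Hypothetical helper: create the 'img' and 'lbl' folders used by the
    # cropping functions; exist_ok=True already makes repeated calls safe.
    img_item_path = os.path.join(save_path, 'img')
    lbl_item_path = os.path.join(save_path, 'lbl')
    os.makedirs(img_item_path, exist_ok=True)
    os.makedirs(lbl_item_path, exist_ok=True)
    return img_item_path, lbl_item_path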


@@ -194,7 +194,7 @@ def predict(model_path:str, test_img_list:List[str], save_path:str)->str:
    return filename
# main
#
# if __name__ == '__main__':
#     # train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist()
#     # valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist()
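Only the signature of predict is visible in this hunk: a checkpoint path, a list of test image paths, and an output directory, returning the name of the file it writes. A hedged usage sketch, assuming predict is imported from this module and with every path purely illustrative:

import glob

# Illustrative paths; the real checkpoint location depends on the
# TensorBoardLogger/ModelCheckpoint settings used during training.
ckpt_path = 'logs/0/version_0/checkpoints/last.ckpt'
test_imgs = glob.glob('./train_result/test/img/*.png')
result_file = predict(ckpt_path, test_imgs, './predict_out')
print('predictions written to', result_file)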


@@ -16,11 +16,11 @@ from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl
from pre import crop_slide, process_slide
from core.model import *
from core.data import *
from core.metrics import *
from typing import Dict
# constants
DATASETS = '0'
@@ -93,12 +93,49 @@ class FCRN(pl.LightningModule):
        return pred.squeeze(), lbl
# main
-if __name__ == '__main__':
-    train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist()
-    valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist()
-    test_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'test'))).tolist()
def pre_process(image_path: str, save_path: str) -> Dict[str, str]:
    # dataset split
    subpath = ["train", "valid"]
    steps = {
        "train": 64,
        "valid": 256,
    }
    result = {}
    for sub in subpath:
        full_path = os.path.join(image_path, sub)
        print("process path:", full_path)
        if not os.path.exists(full_path):
            print("{} not exist".format(full_path))
            raise Exception("{} not exist".format(full_path))
        save_full_path = os.path.join(save_path, sub)
        if not os.path.exists(save_full_path):
            os.makedirs(save_full_path, exist_ok=True)
        for item in glob.glob(os.path.join(full_path, "*.jpg")):
            crop_slide(item, save_full_path, 256, steps[sub])
        result[sub] = save_full_path
    test_img_list = np.array(glob.glob(os.path.join(image_path, 'test/*.jpg'))).tolist()
    save_test_path = os.path.join(save_path, 'test')
    if not os.path.exists(save_test_path):
        os.makedirs(save_test_path, exist_ok=True)
    for img in test_img_list:
        process_slide(img, save_test_path)
    result["test"] = save_test_path
    return result

def train(image_path: str, save_path: str) -> str:
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)
    # dataset split
    pre_result = pre_process(image_path, save_path)
    train_img_list = np.array(glob.glob('{}/img/*.png'.format(pre_result['train']))).tolist()
    valid_img_list = np.array(glob.glob('{}/img/*.png'.format(pre_result['valid']))).tolist()
    test_img_list = np.array(glob.glob('{}/img/*.png'.format(pre_result['test']))).tolist()
    print('Train nums: {}, Valid nums: {}, Test nums: {}.'.format(len(train_img_list), len(valid_img_list),
                                                                  len(test_img_list)))
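pre_process returns a mapping from split name to the directory that now holds the cropped patches, and train immediately globs img/*.png inside each entry. A quick sanity-check sketch (the two paths are the defaults used by the new __main__ block further down):

import glob

paths = pre_process('./train_and_test', './train_result')
# Expected keys, based on the function body above: 'train', 'valid', 'test'.
for split, folder in paths.items():
    patches = glob.glob('{}/img/*.png'.format(folder))
    print(split, folder, len(patches), 'patches')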
@@ -186,3 +223,102 @@ if __name__ == '__main__':
    results_json = json.dumps(results)
    with open(os.path.join(trainer.log_dir, 'test.json'), 'w+') as f:
        f.write(results_json)
# main
if __name__ == '__main__':
    image_path = './train_and_test'
    save_path = './train_result'
    train(image_path, save_path)
# if __name__ == '__main__':
#     train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist()
#     valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist()
#     test_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'test'))).tolist()
#
#     print('Train nums: {}, Valid nums: {}, Test nums: {}.'.format(len(train_img_list), len(valid_img_list),
#                                                                   len(test_img_list)))
#
#     train_dataset = MyDataset(train_img_list, dim=DIM, sigma=SIGMA, data_type='train')
#     valid_dataset = MyDataset(valid_img_list, dim=DIM, sigma=SIGMA, data_type='valid')
#     test_dataset = MyDatasetSlide(test_img_list, dim=SLIDE_DIM)
#
#     train_loader = DataLoader(
#         dataset=train_dataset,
#         batch_size=BS,
#         num_workers=NW,
#         drop_last=True,
#         sampler=WeightedRandomSampler(train_dataset.sample_weights, len(train_dataset))
#     )
#
#     valid_loader = DataLoader(
#         dataset=valid_dataset,
#         batch_size=BS,
#         shuffle=False,
#         num_workers=NW,
#     )
#
#     test_loader = DataLoader(
#         dataset=test_dataset,
#         batch_size=1,
#         shuffle=False,
#         num_workers=1,
#     )
#
#     model = FCRN(IN_CHANNELS)
#
#     logger = TensorBoardLogger(
#         name=DATASETS,
#         save_dir='logs',
#     )
#
#     checkpoint_callback = ModelCheckpoint(
#         every_n_epochs=EVERY_N_EPOCHS,
#         save_top_k=SAVE_TOP_K,
#         monitor='val_dice',
#         mode='max',
#         save_last=True,
#         filename='{epoch}-{val_loss:.2f}-{val_dice:.2f}'
#     )
#
#     earlystop_callback = EarlyStopping(
#         monitor="val_dice",
#         mode="max",
#         min_delta=0.00,
#         patience=EARLY_STOP,
#     )
#
#     # training
#     trainer = pl.Trainer(
#         accelerator='gpu',
#         devices=GPUS,
#         max_epochs=EPOCHS,
#         logger=logger,
#         callbacks=[checkpoint_callback, earlystop_callback]
#     )
#
#     trainer.fit(
#         model,
#         train_loader,
#         valid_loader
#     )
#
#     # inference
#     predictions = trainer.predict(
#         model=model,
#         dataloaders=test_loader,
#         ckpt_path='best'
#     )
#
#     preds = np.concatenate([test_dataset.spliter.recover(item[0])[np.newaxis, :, :] for item in predictions]).tolist()
#     labels = torch.squeeze(torch.concat([item[1] for item in predictions])).numpy().tolist()
#
#     results = {
#         'img_path': test_img_list,
#         'pred': preds,
#         'label': labels,
#     }
#
#     results_json = json.dumps(results)
#     with open(os.path.join(trainer.log_dir, 'test.json'), 'w+') as f:
#         f.write(results_json)
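After training and prediction, the results dictionary (img_path, pred, label) is dumped to test.json inside the Lightning log directory. A hedged sketch for loading it back; the log path is only an example of the default logs/<name>/version_N layout produced by TensorBoardLogger(save_dir='logs', name=DATASETS):

import json
import os

log_dir = 'logs/0/version_0'  # illustrative: first run with DATASETS = '0'
with open(os.path.join(log_dir, 'test.json')) as f:
    results = json.load(f)

print(len(results['img_path']), 'test images')
print('keys:', sorted(results))  # expected: ['img_path', 'label', 'pred']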