训练整合
This commit is contained in:
parent
17f801f7ae
commit
32c4a249e0
|
@ -254,7 +254,7 @@ def generate_report(save_path: Dict[str, str], output_dir: str) -> None:
|
|||
|
||||
post_process_img_section = ReportSection(title="预测结果", ncols=3, elements=img_elements)
|
||||
|
||||
report = Report(name="预测结果",title="原子定位结果", sections=[ post_process_img_section, ori_img_section, pre_process_img_section])
|
||||
report = Report(title="原子定位结果", sections=[post_process_img_section, ori_img_section, pre_process_img_section])
|
||||
report.save(output_dir)
|
||||
|
||||
|
||||
|
|
|
@ -49,10 +49,19 @@ def crop_slide(img_path, save_path, patch_size=256, step=128):
|
|||
mask = np.zeros_like(img)
|
||||
mask[points[:, 0], points[:, 1]] = labels
|
||||
|
||||
img_item_path = os.path.join(save_path, 'img')
|
||||
if not os.path.exists(img_item_path):
|
||||
os.makedirs(img_item_path, exist_ok=True)
|
||||
|
||||
lbl_item_path = os.path.join(save_path, 'lbl')
|
||||
if not os.path.exists(lbl_item_path):
|
||||
os.makedirs(lbl_item_path, exist_ok=True)
|
||||
|
||||
for i in range(0, h - patch_size + 1, step):
|
||||
for j in range(0, w - patch_size + 1, step):
|
||||
v_nums = np.sum(mask[i:i + patch_size, j:j + patch_size] > 1)
|
||||
|
||||
|
||||
Image.fromarray(img[i:i + patch_size, j:j + patch_size]).save(
|
||||
os.path.join(save_path, 'img', '{}_{}_{}_{}.png'.format(base_name, str(i), str(j), str(v_nums)))
|
||||
)
|
||||
|
@ -84,6 +93,15 @@ def process_slide(img_path, save_path):
|
|||
mask = np.zeros_like(img)
|
||||
mask[points[:, 0], points[:, 1]] = labels
|
||||
|
||||
img_item_path = os.path.join(save_path, 'img')
|
||||
if not os.path.exists(img_item_path):
|
||||
os.makedirs(img_item_path, exist_ok=True)
|
||||
|
||||
|
||||
lbl_item_path = os.path.join(save_path, 'lbl')
|
||||
if not os.path.exists(lbl_item_path):
|
||||
os.makedirs(lbl_item_path, exist_ok=True)
|
||||
|
||||
Image.fromarray(img).save(
|
||||
os.path.join(save_path, 'img', '{}.png'.format(base_name))
|
||||
)
|
||||
|
|
|
@ -194,7 +194,7 @@ def predict(model_path:str, test_img_list:List[str], save_path:str)->str:
|
|||
|
||||
return filename
|
||||
# main
|
||||
|
||||
#
|
||||
# if __name__ == '__main__':
|
||||
# # train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist()
|
||||
# # valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist()
|
||||
|
|
|
@ -16,11 +16,11 @@ from pytorch_lightning.loggers import TensorBoardLogger
|
|||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||
import pytorch_lightning as pl
|
||||
|
||||
from pre import crop_slide, process_slide
|
||||
from core.model import *
|
||||
from core.data import *
|
||||
from core.metrics import *
|
||||
|
||||
from typing import Dict
|
||||
# constants
|
||||
|
||||
DATASETS = '0'
|
||||
|
@ -93,12 +93,49 @@ class FCRN(pl.LightningModule):
|
|||
return pred.squeeze(), lbl
|
||||
|
||||
|
||||
# main
|
||||
def pre_process(image_path: str, save_path: str)->Dict[str, str]:
|
||||
# 数据集划分
|
||||
subpath = ["train", "valid"]
|
||||
steps = {
|
||||
"train": 64,
|
||||
"valid": 256,
|
||||
|
||||
if __name__ == '__main__':
|
||||
train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist()
|
||||
valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist()
|
||||
test_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'test'))).tolist()
|
||||
}
|
||||
result = {}
|
||||
for sub in subpath:
|
||||
full_path = os.path.join(image_path, sub)
|
||||
print("process path:", full_path)
|
||||
if not os.path.exists(full_path):
|
||||
print("f{} not exist".format(full_path))
|
||||
raise Exception("f{} not exist".format(full_path))
|
||||
|
||||
save_full_path = os.path.join(save_path, sub)
|
||||
if not os.path.exists(save_full_path):
|
||||
os.makedirs(save_full_path, exist_ok=True)
|
||||
|
||||
for item in glob.glob(os.path.join(full_path, "*.jpg")):
|
||||
crop_slide(item, save_full_path, 256, steps[sub])
|
||||
result[sub] = save_full_path
|
||||
|
||||
test_imag_list = np.array(glob.glob(os.path.join(image_path, 'test/*.jpg'))).tolist()
|
||||
save_test_path = os.path.join(save_path, 'test')
|
||||
if not os.path.exists(save_test_path):
|
||||
os.makedirs(save_test_path, exist_ok=True)
|
||||
|
||||
for img in test_imag_list:
|
||||
process_slide(img, save_test_path)
|
||||
|
||||
result["test"] = save_test_path
|
||||
return result
|
||||
|
||||
def train(image_path: str, save_path: str)->str:
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
# 数据集划分
|
||||
pre_result = pre_process(image_path, save_path)
|
||||
train_img_list = np.array(glob.glob('{}/img/*.png'.format(pre_result['train']))).tolist()
|
||||
valid_img_list = np.array(glob.glob('{}/img/*.png'.format(pre_result['valid']))).tolist()
|
||||
test_img_list = np.array(glob.glob('{}/img/*.png'.format(pre_result['test']))).tolist()
|
||||
|
||||
print('Train nums: {}, Valid nums: {}, Test nums: {}.'.format(len(train_img_list), len(valid_img_list),
|
||||
len(test_img_list)))
|
||||
|
@ -186,3 +223,102 @@ if __name__ == '__main__':
|
|||
results_json = json.dumps(results)
|
||||
with open(os.path.join(trainer.log_dir, 'test.json'), 'w+') as f:
|
||||
f.write(results_json)
|
||||
|
||||
# main
|
||||
|
||||
if __name__ == '__main__':
|
||||
image_path = './train_and_test'
|
||||
save_path = './train_result'
|
||||
|
||||
train(image_path, save_path)
|
||||
# if __name__ == '__main__':
|
||||
# train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist()
|
||||
# valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist()
|
||||
# test_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'test'))).tolist()
|
||||
#
|
||||
# print('Train nums: {}, Valid nums: {}, Test nums: {}.'.format(len(train_img_list), len(valid_img_list),
|
||||
# len(test_img_list)))
|
||||
#
|
||||
# train_dataset = MyDataset(train_img_list, dim=DIM, sigma=SIGMA, data_type='train')
|
||||
# valid_dataset = MyDataset(valid_img_list, dim=DIM, sigma=SIGMA, data_type='valid')
|
||||
# test_dataset = MyDatasetSlide(test_img_list, dim=SLIDE_DIM)
|
||||
#
|
||||
# train_loader = DataLoader(
|
||||
# dataset=train_dataset,
|
||||
# batch_size=BS,
|
||||
# num_workers=NW,
|
||||
# drop_last=True,
|
||||
# sampler=WeightedRandomSampler(train_dataset.sample_weights, len(train_dataset))
|
||||
# )
|
||||
#
|
||||
# valid_loader = DataLoader(
|
||||
# dataset=valid_dataset,
|
||||
# batch_size=BS,
|
||||
# shuffle=False,
|
||||
# num_workers=NW,
|
||||
# )
|
||||
#
|
||||
# test_loader = DataLoader(
|
||||
# dataset=test_dataset,
|
||||
# batch_size=1,
|
||||
# shuffle=False,
|
||||
# num_workers=1,
|
||||
# )
|
||||
#
|
||||
# model = FCRN(IN_CHANNELS)
|
||||
#
|
||||
# logger = TensorBoardLogger(
|
||||
# name=DATASETS,
|
||||
# save_dir='logs',
|
||||
# )
|
||||
#
|
||||
# checkpoint_callback = ModelCheckpoint(
|
||||
# every_n_epochs=EVERY_N_EPOCHS,
|
||||
# save_top_k=SAVE_TOP_K,
|
||||
# monitor='val_dice',
|
||||
# mode='max',
|
||||
# save_last=True,
|
||||
# filename='{epoch}-{val_loss:.2f}-{val_dice:.2f}'
|
||||
# )
|
||||
#
|
||||
# earlystop_callback = EarlyStopping(
|
||||
# monitor="val_dice",
|
||||
# mode="max",
|
||||
# min_delta=0.00,
|
||||
# patience=EARLY_STOP,
|
||||
# )
|
||||
#
|
||||
# # training
|
||||
# trainer = pl.Trainer(
|
||||
# accelerator='gpu',
|
||||
# devices=GPUS,
|
||||
# max_epochs=EPOCHS,
|
||||
# logger=logger,
|
||||
# callbacks=[checkpoint_callback, earlystop_callback]
|
||||
# )
|
||||
#
|
||||
# trainer.fit(
|
||||
# model,
|
||||
# train_loader,
|
||||
# valid_loader
|
||||
# )
|
||||
#
|
||||
# # inference
|
||||
# predictions = trainer.predict(
|
||||
# model=model,
|
||||
# dataloaders=test_loader,
|
||||
# ckpt_path='best'
|
||||
# )
|
||||
#
|
||||
# preds = np.concatenate([test_dataset.spliter.recover(item[0])[np.newaxis, :, :] for item in predictions]).tolist()
|
||||
# labels = torch.squeeze(torch.concat([item[1] for item in predictions])).numpy().tolist()
|
||||
#
|
||||
# results = {
|
||||
# 'img_path': test_img_list,
|
||||
# 'pred': preds,
|
||||
# 'label': labels,
|
||||
# }
|
||||
#
|
||||
# results_json = json.dumps(results)
|
||||
# with open(os.path.join(trainer.log_dir, 'test.json'), 'w+') as f:
|
||||
# f.write(results_json)
|
Loading…
Reference in New Issue