修改成正确的路径

This commit is contained in:
somunslotus 2024-07-10 09:18:15 +08:00
parent 3bcb595270
commit 4c5599be29
4 changed files with 59 additions and 53 deletions

View File

@ -11,10 +11,10 @@ from skimage.measure import label
from skimage.measure import regionprops
from utils.labelme import save_pred_to_json,save_pred_to_json
json_path = '/home/gao/mouclear/cc/code_v2/msunet/logs/0/version_3/test.json'
json_path = './result/predict/test.json'
# save_path = '/home/gao/mouclear/cc/data/detect_50/替换图像'
# json_path = '/home/gao/mouclear/cc/code/pd_2/code/egnn/logs/0/version_2/test.json'
save_path = '/home/gao/mouclear/cc/data_new/analyze_50'
save_path = './result/post_process/'
os.makedirs(save_path, exist_ok=True)
save_pred_to_json(json_path, save_path)

View File

@ -94,6 +94,7 @@ def process_slide(img_path, save_path):
img_lst = glob.glob('/home/gao/mouclear/cc/data_new/msunet/train_and_test/test/*.jpg')
img_lst = glob.glob('./train_and_test/test/*.jpg')
for item in img_lst:
process_slide(item, save_path='/home/gao/mouclear/cc/data_new/msunet/train_and_test/test/')
process_slide(item, save_path='./result/pre_process/')

View File

@ -2,6 +2,7 @@ import os
import json
import glob
import torch
torch.set_float32_matmul_precision('high')
import numpy as np
@ -43,7 +44,8 @@ EARLY_STOP = 5
EVERY_N_EPOCHS = 1
IMAGE_PATH = '../../data/linesv/patch_unet/'
# IMAGE_PATH2 = '/home/gao/下载/process/test/'
WEIGHTS = torch.FloatTensor([1./8, 1./4, 1./2, 1.])
WEIGHTS = torch.FloatTensor([1. / 8, 1. / 4, 1. / 2, 1.])
# pytorch lightning module
@ -95,19 +97,21 @@ class FCRN(pl.LightningModule):
return pred.squeeze(), lbl
# main
if __name__ == '__main__':
train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist()
valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist()
test_img_list = np.array(glob.glob('/home/gao/mouclear/cc/data_new/analyze_50/*.jpg')).tolist()
# train_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'train'))).tolist()
# valid_img_list = np.array(glob.glob('{}/{}/img/*.png'.format(IMAGE_PATH, 'valid'))).tolist()
test_img_list = np.array(glob.glob('./result/pre_process/img/*.png')).tolist()
# test_img_list = np.array(glob.glob('/home/gao/mouclear/detect_40/40/*.jpg')).tolist()
print('Train nums: {}, Valid nums: {}, Test nums: {}.'.format(len(train_img_list), len(valid_img_list), len(test_img_list)))
# print('Train nums: {}, Valid nums: {}, Test nums: {}.'.format(len(train_img_list), len(valid_img_list), len(test_img_list)))
print('Test nums: ', len(test_img_list))
# train_dataset = MyDataset(train_img_list, dim=DIM, sigma=SIGMA, data_type='train')
# valid_dataset = MyDataset(valid_img_list, dim=DIM, sigma=SIGMA, data_type='valid')
test_dataset = MyDatasetSlide_test_uneuqal(test_img_list, ps=patch_size,roi=roi_size)
test_dataset = MyDatasetSlide_test_uneuqal(test_img_list, ps=patch_size, roi=roi_size)
# train_loader = DataLoader(
# dataset = train_dataset,
@ -125,15 +129,14 @@ if __name__ == '__main__':
# )
test_loader = DataLoader(
dataset = test_dataset,
batch_size = 1,
shuffle = False,
num_workers = 1,
dataset=test_dataset,
batch_size=1,
shuffle=False,
num_workers=1,
)
model = FCRN(IN_CHANNELS)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)
# checkpoint = torch.load('/home/gao/mouclear/cc/code/msunet/logs/0/version_0/checkpoints/last.ckpt', map_location=device)
@ -153,35 +156,34 @@ if __name__ == '__main__':
# 现在 new_state_dict 包含了更新后的键,可以安全地加载到模型中
# model.load_state_dict(state_dict) # 使用 strict=False 以忽略不匹配的键
logger = TensorBoardLogger(
name = DATASETS,
save_dir = 'logs',
name=DATASETS,
save_dir='logs',
)
checkpoint_callback = ModelCheckpoint(
every_n_epochs = EVERY_N_EPOCHS,
save_top_k = SAVE_TOP_K,
monitor = 'val_dice',
mode = 'max',
save_last = True,
filename = '{epoch}-{val_loss:.2f}-{val_dice:.2f}'
every_n_epochs=EVERY_N_EPOCHS,
save_top_k=SAVE_TOP_K,
monitor='val_dice',
mode='max',
save_last=True,
filename='{epoch}-{val_loss:.2f}-{val_dice:.2f}'
)
earlystop_callback = EarlyStopping(
monitor = "val_dice",
mode = "max",
min_delta = 0.00,
patience = EARLY_STOP,
monitor="val_dice",
mode="max",
min_delta=0.00,
patience=EARLY_STOP,
)
# training
trainer = pl.Trainer(
accelerator = 'gpu',
devices = GPUS,
max_epochs = EPOCHS,
logger = logger,
callbacks = [checkpoint_callback, earlystop_callback],
accelerator='gpu',
devices=GPUS,
max_epochs=EPOCHS,
logger=logger,
callbacks=[checkpoint_callback, earlystop_callback],
)
# trainer.fit(
@ -192,24 +194,25 @@ if __name__ == '__main__':
# inference
predictions = trainer.predict(
model = model,
dataloaders = test_loader,
ckpt_path = '/home/gao/mouclear/cc/code/msunet/logs/0/version_1/checkpoints/last.ckpt'
model=model,
dataloaders=test_loader,
ckpt_path='./result/train/last.ckpt'
)
# pre = = spliter.recover(patches, item[1].shape[1], item[1].shape[2], ps=patch_size,roi=roi_size)
preds = np.concatenate([test_dataset.spliter.recover(item[0], item[1].shape[1], item[1].shape[2], ps=patch_size,roi=roi_size)[np.newaxis, :, :] for item in predictions]).tolist()
preds = np.concatenate([test_dataset.spliter.recover(item[0], item[1].shape[1], item[1].shape[2], ps=patch_size,
roi=roi_size)[np.newaxis, :, :] for item in
predictions]).tolist()
labels = torch.squeeze(torch.concat([item[1] for item in predictions])).numpy().tolist()
results ={
results = {
'img_path': test_img_list,
'pred': preds,
'label': labels,
}
predict_dir= './result/predict/'
results_json = json.dumps(results)
with open(os.path.join(trainer.log_dir, 'test.json'), 'w+') as f:
with open(os.path.join(predict_dir, 'test.json'), 'w+') as f:
f.write(results_json)

View File

@ -27,6 +27,7 @@ class_dict_rev = {
}
def get_json(img_path, lbl):
print("get json image path:", img_path)
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
w, h = np.where(lbl != 0)
@ -34,7 +35,7 @@ def get_json(img_path, lbl):
points = np.array(points, np.int16).tolist()
shapes = [{"label": class_dict[lbl[item[1], item[0]]], "points": [item], "group_id": None, "shape_type": "point", "flags": {}} for item in points]
imagePath = img_path.split('/')[-1]
imageData = utils.img_arr_to_b64(img).decode('utf-8')
imageData = utils.img_arr_to_b64(img) #.decode('utf-8')
imageHeight, imageWidth = img.shape
json_data = {
@ -168,11 +169,12 @@ def save_pred_to_json(json_path, save_path):
pred = np.array(data['pred'])
lb = np.array(data['label'], dtype=np.uint8)
img_path = np.array(data['img_path'])
print("image path: ", img_path)
nums = len(img_path)
for i in range(nums):
name = img_path[i].split('/')[-1].split('.')[0]
print("name: ", name)
# gt
Image.fromarray(lb[i]).save('{}/{}.png'.format(save_path, name))
Image.open(img_path[i]).save('{}/{}.jpg'.format(save_path, name))