atom-predict/msunet/plot.py

72 lines
1.7 KiB
Python
Executable File

import os
import cv2
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
#from core.data import load_data_plot
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
def load_data(json_path):
    """Load annotated point coordinates and class labels from a JSON file.

    The file is expected to contain a ``shapes`` list whose items carry a
    ``points`` list and a ``label`` string (presumably labelme-style
    annotations -- confirm against the producer). Only the first point of
    each shape is used.

    Args:
        json_path: Path of the annotation JSON file.

    Returns:
        A ``(points, labels)`` tuple:
        * ``points`` -- int32 array of shape ``(N, 2)`` holding ``(row, col)``
          pairs. Each stored ``[x, y]`` pair is reversed, and fractional
          coordinates are truncated by the int32 cast. It is ``(0, 2)`` when
          there are no shapes.
        * ``labels`` -- uint8 array of shape ``(N,)`` with class ids 0, 1, 2.
          If ANY label is missing or not one of ``'0'``, ``'1'``, ``'2'``,
          every point is labelled 0.
    """
    data_dict = {
        '0': 0,
        '1': 1,
        '2': 2,
    }
    with open(json_path, encoding='utf-8') as f:
        json_data = json.load(f)
    shapes = json_data['shapes']

    if shapes:
        # Reverse each first point so it can index an image array as
        # (row, col); assumes points are stored as [x, y] pairs.
        points = np.array([item['points'][0][::-1] for item in shapes], np.int32)
    else:
        # Keep the 2-D shape so callers can still slice points[:, 0].
        points = np.empty((0, 2), np.int32)

    try:
        labels = np.array([data_dict[item['label']] for item in shapes], np.uint8)
    except KeyError:
        # Missing or unrecognised label: fall back to class 0 for all points.
        labels = np.zeros(len(shapes), np.uint8)
    return points, labels
def plot2(img_path, json_path, save_path=None, psize=20):
    """Overlay annotated points on a grayscale image and save the figure.

    Points and labels come from ``load_data(json_path)``. Class ``i``
    (0, 1, 2) is drawn in ``c[i]`` (red, green, yellow).

    Args:
        img_path: Path of the image to draw on; it is read as grayscale.
        json_path: Annotation file passed to ``load_data``.
        save_path: Output directory. The figure is written there as
            ``"predict_" + basename(img_path)``. When ``None`` (the default),
            the image's own directory is used. The directory is created if
            it does not exist.
        psize: Marker size passed to ``plt.scatter`` as ``s``.

    Returns:
        The path of the saved figure. The original returned ``None``, so
        callers that ignore the return value are unaffected.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    c = ['red', 'green', 'yellow']
    img = cv2.imread(img_path, 0)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"cannot read image: {img_path}")
    points, labels = load_data(json_path)
    # Guarantee a 2-D (N, 2) array even when there are no annotations.
    points = np.asarray(points).reshape(-1, 2)
    labels = np.asarray(labels)

    # Rasterise the points into a label mask (0 = background, class id + 1
    # otherwise). A pixel hit by several points ends up with a single label.
    mask_pd = np.zeros(img.shape[:2], np.uint8)
    rows, cols = points[:, 0], points[:, 1]
    # Drop points outside the image: large indices would raise IndexError,
    # and negative ones would silently wrap to the opposite edge.
    inside = (
        (rows >= 0) & (rows < mask_pd.shape[0])
        & (cols >= 0) & (cols < mask_pd.shape[1])
    )
    mask_pd[rows[inside], cols[inside]] = labels[inside] + 1

    # Resolve the output location up front; the original crashed on the
    # default save_path=None (os.path.join(None, ...)).
    if save_path is None:
        save_path = os.path.dirname(img_path)
    if save_path:
        # Make sure the result directory exists.
        os.makedirs(save_path, exist_ok=True)
    file_name = os.path.basename(img_path)
    result_path = os.path.join(save_path, "predict_" + file_name)

    fig = plt.figure(figsize=(9, 9))
    try:
        plt.imshow(img, cmap='gray')
        for i, color in enumerate(c):
            h, w = np.where(mask_pd == i + 1)
            # scatter takes x = column, y = row.
            plt.scatter(w, h, s=psize, c=color)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(result_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
    return result_path
# json_path = './result/post_process/10.json'
# img_path = './result/post_process/10.jpg'
#
# plot2(img_path, json_path)