413 lines
16 KiB
Python
413 lines
16 KiB
Python
import os
|
||
import cv2
|
||
import glob
|
||
import json
|
||
import numpy as np
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
|
||
from tqdm import tqdm
|
||
from PIL import Image
|
||
from egnn_core.data import load_data
|
||
from egnn_core.data import load_data_v2
|
||
from model_type_dict import ModelLabel
|
||
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
|
||
from typing import Dict, List
|
||
|
||
# TODO: support image sizes other than 2048;
# take the size from img.shape instead of a fixed constant.
|
||
def plot1(json_path, output_path, psize=24, img=None):
    """Draw the graph stored in ``json_path`` over ``img`` and save it as a JPEG.

    Nodes are drawn as dots and edges as thin grey lines on top of the
    image.  The result is written to ``<output_path>/<json stem>.jpg``.

    Args:
        json_path: Path handed to ``load_data``; its base name (without
            extension) also names the output file.
        output_path: Directory that receives the rendered image.
        psize: Currently unused; the node marker size is fixed at 16.  The
            parameter is kept so existing callers keep working.
        img: Background image array.  Required, despite the ``None`` default.

    Raises:
        ValueError: If ``img`` is ``None``.
    """
    if img is None:
        raise ValueError("plot1 requires an image array; got img=None")

    points, edge_index, _labels, _ = load_data(json_path)

    plt.figure(figsize=(9, 9))
    try:
        plt.imshow(img, cmap='gray')
        # Column 1 of ``points`` is used as x and column 0 as y.
        plt.scatter(points[:, 1], points[:, 0], s=16, c='#8EB9D9', zorder=2)

        # Each column of ``edge_index`` is one (start, end) pair of point indices.
        for start, end in edge_index.T:
            p_start = points[start]
            p_end = points[end]
            plt.plot(
                [p_start[1], p_end[1]],
                [p_start[0], p_end[0]],
                linewidth=1,
                c='#C0C0C0',
                zorder=1,
            )

        plt.axis('off')
        plt.tight_layout()

        json_filename = os.path.splitext(os.path.basename(json_path))[0]
        output_file = os.path.join(output_path, f"{json_filename}.jpg")
        plt.savefig(output_file)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close()
|
||
|
||
|
||
def plot2(img_folder, json_folder, psize=26, output_folder='output'):
    """Render per-class point overlays for every ``.jpg`` in ``img_folder``.

    For each image, the JSON file with the same stem is read from
    ``json_folder`` via ``load_data_v2``.  Points are rasterised into a label
    mask the size of the image, and each of the labels 0-3 is scattered in its
    own colour on top of the grayscale image.  The figure is saved under the
    image's file name in ``output_folder``.

    Args:
        img_folder: Directory scanned for ``.jpg`` files.
        json_folder: Directory holding the matching ``.json`` files.
        psize: Marker size passed to ``plt.scatter``.
        output_folder: Destination directory; created if missing.
    """
    palette = ['#9BB6CF', '#76F1A2', '#EDC08C', 'red']

    # ``exist_ok=True`` already covers the "directory exists" case.
    os.makedirs(output_folder, exist_ok=True)

    for img_filename in os.listdir(img_folder):
        if not img_filename.endswith(".jpg"):  # only jpg files are processed
            continue

        img_path = os.path.join(img_folder, img_filename)
        # Swap only the extension; str.replace would also rewrite any
        # ".jpg" that appears earlier in the file name.
        json_filename = os.path.splitext(img_filename)[0] + '.json'
        json_path = os.path.join(json_folder, json_filename)

        img = cv2.imread(img_path, 0)
        points, _edge_index, labels, _ = load_data_v2(json_path)

        # Label mask: 0 is background, label k is stored as k + 1.
        mask_pd = np.zeros(img.shape)
        mask_pd[points[:, 0], points[:, 1]] = labels + 1
        mask_pd = np.array(mask_pd, np.uint8)

        plt.figure(figsize=(9, 9))
        try:
            plt.imshow(img, cmap='gray')
            for i in range(4):
                h, w = np.where(mask_pd == i + 1)
                plt.scatter(w, h, s=psize, c=palette[i])

            plt.axis('off')
            plt.tight_layout()

            output_path = os.path.join(output_folder, img_filename)
            print(f"Save plot to {output_path}")
            plt.savefig(output_path, dpi=300)
        finally:
            # Close the figure even on failure so figures do not accumulate.
            plt.close()
|
||
|
||
import os
|
||
import cv2
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from PIL import Image
|
||
import json
|
||
|
||
# Assumes a function for loading the JSON data is available.
|
||
|
||
def plot_multiple_annotations(img_folder, json_folders, psize=26, output_folder='output'):
    """Overlay the predictions of several models on every ``.jpg`` image.

    For each image in ``img_folder``, the JSON file with the same stem is
    looked up in every folder of ``json_folders``.  Folders without that file
    are skipped.  All points of one model share one colour; the colour list is
    reused cyclically when there are more models than colours.

    Args:
        img_folder: Directory scanned for ``.jpg`` files.
        json_folders: Sequence of directories, one per model.
        psize: Marker size passed to ``plt.scatter``.
        output_folder: Destination directory; created if missing.
    """
    model_colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#FF00FF', '#00FFFF']

    os.makedirs(output_folder, exist_ok=True)

    for img_filename in os.listdir(img_folder):
        if not img_filename.endswith(".jpg"):  # only jpg files are processed
            continue

        img_path = os.path.join(img_folder, img_filename)
        img = cv2.imread(img_path, 0)
        # Swap only the extension, not every ".jpg" substring in the name.
        json_filename = os.path.splitext(img_filename)[0] + '.json'

        plt.figure(figsize=(9, 9))
        try:
            plt.imshow(img, cmap='gray')

            for model_idx, json_folder in enumerate(json_folders):
                print(f"Processing {json_filename}")
                json_path = os.path.join(json_folder, json_filename)
                if not os.path.exists(json_path):
                    continue

                points, _edge_index, labels, _ = load_data_v2(json_path)
                # Label mask: 0 is background, label k is stored as k + 1.
                mask_pd = np.zeros(img.shape, np.uint8)
                mask_pd[points[:, 0], points[:, 1]] = labels + 1

                # Wrap around so more than six models do not raise IndexError.
                color = model_colors[model_idx % len(model_colors)]
                for i in range(4):
                    h, w = np.where(mask_pd == i + 1)
                    plt.scatter(
                        w,
                        h,
                        s=psize,
                        c=color,
                        label=f'Model {model_idx+1}, Class {i+1}',
                        alpha=0.5,
                    )

            plt.axis('off')
            plt.tight_layout()

            output_path = os.path.join(output_folder, img_filename)
            print(f"Save plot to {output_path}")
            plt.savefig(output_path, dpi=300)
        finally:
            # Close the figure even on failure so figures do not accumulate.
            plt.close()
|
||
|
||
def plot_arange_combine(img_folder, json_folders: Dict[str, str], select_labels: List[ModelLabel], output_folder: str, psize=40) -> None:
    """Combine the selected labels of several models into one overlay per image.

    For every ``.jpg`` in ``img_folder`` the models listed in
    ``select_labels`` are processed in order:

    * The first model is the base layer.  All of its points are drawn, and
      the indices of the points carrying its selected labels are remembered.
    * For every later model, only points carrying one of its selected labels
      are drawn, minus the indices remembered from the base layer.

    Each model uses its own three-colour palette (labels 0-2); palettes are
    reused cyclically when there are more models than palettes.  The result
    is saved under the image's file name in ``output_folder``.

    Args:
        img_folder: Directory scanned for ``.jpg`` files.
        json_folders: Maps a model name to the directory with its JSON files.
        select_labels: Items exposing ``name`` (a key of ``json_folders``)
            and ``label`` (an iterable of label values to select).
        output_folder: Destination directory; created if missing.
        psize: Marker size passed to ``plt.scatter``.

    Note:
        Base-layer indices are compared directly with indices from the other
        models' files.  This presumably relies on all JSON files listing the
        points in the same order -- TODO confirm against ``load_data_v2``.
    """
    os.makedirs(output_folder, exist_ok=True)

    palettes = [
        ['#FF5733', '#33FF57', '#3357FF'],  # Red, Green, Blue
        ['#FF33FF', '#33FFFF', '#FFFF33'],  # Magenta, Cyan, Yellow
        ['#F08080', '#90EE90', '#ADD8E6'],  # Light Coral, Light Green, Light Blue
        ['#FFD700', '#8A2BE2', '#FF4500'],  # Gold, Blue Violet, Orange Red
    ]

    def scatter_labels(shape, pts, lbls, palette):
        # Rasterise the points into a label mask (label k stored as k + 1)
        # and scatter labels 0-2 in their palette colours.
        mask_pd = np.zeros(shape, np.uint8)
        mask_pd[pts[:, 0], pts[:, 1]] = lbls + 1
        for q in range(3):
            h, w = np.where(mask_pd == q + 1)
            plt.scatter(w, h, s=psize, c=palette[q])

    for img_filename in os.listdir(img_folder):
        # Filter before creating the figure so non-jpg entries do not leak one.
        if not img_filename.endswith(".jpg"):
            continue

        img_path = os.path.join(img_folder, img_filename)
        img = cv2.imread(img_path, 0)
        json_name = os.path.splitext(os.path.basename(img_filename))[0] + '.json'

        plt.figure(figsize=(9, 9))
        try:
            base_points_indexs = []
            for i, select_label in enumerate(select_labels):
                json_path = os.path.join(json_folders[select_label.name], json_name)
                points, _edge_index, labels, _ = load_data_v2(json_path)
                palette = palettes[i % len(palettes)]

                if i == 0:
                    # Base layer: remember the selected points, draw everything.
                    base_points_indexs = [
                        np.where(labels == wanted)[0]
                        for wanted in select_label.label
                    ]
                    scatter_labels(img.shape, points, labels, palette)
                    continue

                # Later layers: draw every selected label, not only the last
                # one, after removing the base-layer points.
                for wanted in select_label.label:
                    keep = np.where(labels == wanted)[0]
                    for base_points_index in base_points_indexs:
                        keep = keep[~np.isin(keep, base_points_index)]
                    scatter_labels(img.shape, points[keep], labels[keep], palette)

            plt.imshow(img, cmap='gray')
            plt.axis('off')
            plt.tight_layout()
            # Save per image into output_folder instead of overwriting one
            # fixed file in the working directory.
            plt.savefig(os.path.join(output_folder, img_filename), dpi=300)
        finally:
            # Close the figure even on failure so figures do not accumulate.
            plt.close()
|
||
|
||
|
||
|
||
|
||
|
||
# Requirement: draw the image that corresponds to the labels selected by the user.
|
||
def plot_combine(img_path, json_path, json_path2, json_path3, json_path4, psize=40):
    """Overlay the predictions of three JSON files on one image.

    Layers are drawn in this order:

    1. ``json_path``: all points.  The indices of its label-1 points are
       remembered and excluded from the later layers.
    2. ``json_path2``: remaining label-1 points, then remaining label-0 points.
    3. ``json_path3``: remaining label-1 points.

    ``json_path4`` is accepted but not used; its layer is disabled.  The
    figure is saved as ``kw.jpg`` in the working directory and then shown.
    """
    img_holder = {}

    def label_mask(pts, lbls):
        # Label mask the size of the image; label k is stored as k + 1.
        mask = np.zeros(img_holder['img'].shape, np.uint8)
        mask[pts[:, 0], pts[:, 1]] = lbls + 1
        return np.array(mask, np.uint8)

    def scatter_mask(mask, palette):
        # Scatter labels 0-2 found in ``mask`` using ``palette``.
        for slot, colour in enumerate(palette[:3]):
            rows, cols = np.where(mask == slot + 1)
            plt.scatter(cols, rows, s=psize, c=colour)

    # Layer 1: labels taken from the first prediction file.
    # '#814d81' is purple.
    first_palette = ['#dfc8a0', '#814d81', '#debd97']
    points, edge_index, labels, _ = load_data_v2(json_path)
    img = cv2.imread(img_path, 0)
    img_holder['img'] = img
    select_point_index = np.where(labels == 1)[0]

    first_mask = label_mask(points, labels)
    plt.figure(figsize=(9, 9))
    scatter_mask(first_mask, first_palette)

    # Layer 2: label-1 points first, then label-0 points, both without the
    # points already selected in layer 1.
    second_palette = ['#debd97', '#84aad0', '#debd97']
    points, edge_index, labels, _ = load_data_v2(json_path2)

    zeros_left = np.where(labels == 0)[0]
    zeros_left = zeros_left[~np.isin(zeros_left, select_point_index)]
    ones_left = np.where(labels == 1)[0]
    ones_left = ones_left[~np.isin(ones_left, select_point_index)]

    scatter_mask(label_mask(points[ones_left], labels[ones_left]), second_palette)
    scatter_mask(label_mask(points[zeros_left], labels[zeros_left]), second_palette)

    # Layer 3: label-1 points without the points selected in layer 1.
    third_palette = ['#84aad0', '#ff5b5b', '#ff5b5b']
    points, edge_index, labels, _ = load_data_v2(json_path3)

    ones_left = np.where(labels == 1)[0]
    ones_left = ones_left[~np.isin(ones_left, select_point_index)]
    scatter_mask(label_mask(points[ones_left], labels[ones_left]), third_palette)

    # The layer for json_path4 is intentionally disabled.

    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('kw.jpg', dpi=300)
    plt.show()
|
||
|
||
|
||
def plot_combine1(img_path, json_paths, psize=40):
    """Overlay the predictions of several JSON files on one image.

    All points of the first file are drawn, and the indices of its label-1
    points are remembered.  For every later file only the label-1 points
    whose index is not among the remembered ones are drawn.  Each file uses
    its own three-colour palette; palettes are reused cyclically when there
    are more files than palettes.  The figure is saved as ``kw.jpg`` in the
    working directory and then shown.

    Args:
        img_path: Path of the background image, opened with PIL.
        json_paths: Sequence of JSON paths handed to ``load_data_v2``.
        psize: Marker size passed to ``plt.scatter``.
    """
    colors = [
        ['#dfc8a0', '#814d81', '#debd97'],
        ['#debd97', '#84aad0', '#debd97'],
        ['#84aad0', '#ff5b5b', '#ff5b5b'],
        ['#84aad0', '#92b964', '#ff5b5b'],
    ]

    select_point_index = []

    plt.figure(figsize=(9, 9))
    img = np.array(Image.open(img_path))
    plt.imshow(img, cmap='gray')

    # Size the mask from the image (rows, cols) instead of a fixed 512x512,
    # so larger images do not raise IndexError.
    mask_shape = img.shape[:2]

    for idx, json_path in enumerate(json_paths):
        points, _edge_index, labels, _ = load_data_v2(json_path)

        if idx == 0:
            # Only the first file defines the points excluded later on.
            select_point_index = np.where(labels == 1)[0]
        else:
            keep = np.where(labels == 1)[0]
            keep = keep[~np.isin(keep, select_point_index)]
            points = points[keep]
            labels = labels[keep]

        # Label mask: 0 is background, label k is stored as k + 1.
        mask_pd = np.zeros(mask_shape, np.uint8)
        mask_pd[points[:, 0], points[:, 1]] = labels + 1

        palette = colors[idx % len(colors)]
        for i in range(3):
            h, w = np.where(mask_pd == i + 1)
            plt.scatter(w, h, s=psize, c=palette[i])

    plt.axis('off')
    plt.tight_layout()
    plt.savefig('kw.jpg', dpi=300)
    plt.show()
|
||
|
||
|
||
|
||
if __name__ == '__main__':
    # Example invocation.  Any failure is reported with a traceback instead
    # of propagating.
    try:
        folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process'
        img_folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.jpg'

        # Per-model JSON files, used only by the commented-out calls below.
        json_folders = [
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process/49.json',
            '',
        ]
        new_json_folder = [
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.json',
            '',
        ]

        # Model name -> directory holding that model's JSON predictions.
        dict_json_folder = {
            "正常掺杂分类": '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process',
            "SMoV三种原子分类": '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process',
            "012三种不同结构的原子类型分类": '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process',
        }

        # Models in drawing order, with the labels selected from each.
        select_labels = [
            ModelLabel(name="012三种不同结构的原子类型分类", label=[1]),
            ModelLabel(name="SMoV三种原子分类", label=[0,1]),
            ModelLabel(name="正常掺杂分类", label=[1]),
        ]

        plot_arange_combine(
            folder,
            dict_json_folder,
            select_labels=select_labels,
            output_folder='/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/combine_view_0731',
        )
        # plot2(folder, folder)
        # plot_combine(img_folder, new_json_folder[0], new_json_folder[1], new_json_folder[2], new_json_folder[3])
        # plot_multiple_annotations(img_folder, json_folders, output_folder='/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/combine_view')

    except Exception as e:
        import traceback
        print(f"Error processing error: {e}")
        traceback.print_exc()

    # Folder paths for the images and the JSON files:
    # img_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv'
    # json_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv'
    # img_folder = '/home/gao/mouclear/cc/final_todo/SV/other_result/after_test_v2_dengbian'
    # json_folder = '/home/gao/mouclear/cc/final_todo/SV/other_result/after_test_v2_dengbian'
    #
    # # Call the function
    # plot2(img_folder, json_folder)

    # post_process_save_path = './predict_result/post_process/'
    # plot_save_path = './predict_result/predict_result_view/'
    # os.makedirs(plot_save_path, exist_ok=True)
    # plot2(img_folder=post_process_save_path, json_folder= post_process_save_path, output_folder=plot_save_path)