atom-predict/egnn_v2/plot_view.py

413 lines
16 KiB
Python
Raw Normal View History

2024-07-22 08:48:33 +08:00
import os
import cv2
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
2024-07-23 13:42:42 +08:00
from egnn_core.data import load_data
from egnn_core.data import load_data_v2
2024-08-02 17:10:06 +08:00
from model_type_dict import ModelLabel
2024-07-22 08:48:33 +08:00
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
2024-08-02 17:10:06 +08:00
from typing import Dict, List
2024-07-22 08:48:33 +08:00
2024-08-02 17:10:06 +08:00
# Must support image sizes other than 2048:
# the canvas size is taken from the img.shape of the image passed in.
def plot1(json_path, output_path, psize=24, img=None):
    """Draw the graph stored in ``json_path`` over ``img`` and save it as a JPEG.

    Nodes are drawn as dots and each edge as a thin grey line. The result is
    written to ``<output_path>/<basename of json_path without extension>.jpg``.

    Args:
        json_path: Path of the JSON file passed to ``load_data``.
        output_path: Directory that receives the rendered image; created
            when it does not exist yet.
        psize: Unused; kept so existing callers keep working. The marker
            size stays fixed at 16.
        img: Background image array, displayed with a grey colormap.
            Required even though the signature keeps ``None`` as default.

    Raises:
        ValueError: If ``img`` is ``None``.
    """
    if img is None:
        raise ValueError("plot1() requires the background image via `img`")

    # assumes points is an (N, 2) array of (row, col) coordinates and
    # edge_index a (2, E) array of node indices -- TODO confirm in load_data
    points, edge_index, _labels, _ = load_data(json_path)

    os.makedirs(output_path, exist_ok=True)
    plt.figure(figsize=(9, 9))
    try:
        plt.imshow(img, cmap='gray')
        # Column index is the x axis, row index the y axis.
        plt.scatter(points[:, 1], points[:, 0], s=16, c='#8EB9D9', zorder=2)
        for start, end in edge_index.T:
            p0, p1 = points[start], points[end]
            plt.plot([p0[1], p1[1]], [p0[0], p1[0]], linewidth=1, c='#C0C0C0', zorder=1)
        plt.axis('off')
        plt.tight_layout()
        stem = os.path.splitext(os.path.basename(json_path))[0]
        plt.savefig(os.path.join(output_path, f"{stem}.jpg"))
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close()
2024-07-22 08:48:33 +08:00
def plot2(img_folder, json_folder, psize=26, output_folder='output'):
    """Overlay labelled points on every ``.jpg`` in ``img_folder`` and save the plots.

    For each image ``name.jpg`` the annotation ``name.json`` is read from
    ``json_folder`` with ``load_data_v2``. Points are coloured by label
    (labels 0-3 map to the four palette entries; other labels are not drawn)
    and the figure is saved under the same file name in ``output_folder``.
    Edges are not drawn.

    Args:
        img_folder: Directory scanned for ``.jpg`` files.
        json_folder: Directory holding the matching ``.json`` files.
        psize: Marker size passed to ``plt.scatter``.
        output_folder: Destination directory; created when missing.

    Images that cannot be read, or that have no matching JSON file, are
    reported on stdout and skipped.
    """
    palette = ['#9BB6CF', '#76F1A2', '#EDC08C', 'red']
    os.makedirs(output_folder, exist_ok=True)

    for img_filename in os.listdir(img_folder):
        if not img_filename.endswith(".jpg"):
            continue

        img_path = os.path.join(img_folder, img_filename)
        # Swap only the extension; str.replace would also rewrite a '.jpg'
        # occurring elsewhere in the name.
        json_path = os.path.join(json_folder, os.path.splitext(img_filename)[0] + '.json')

        img = cv2.imread(img_path, 0)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(f"Skip {img_path}: image could not be read")
            continue
        if not os.path.exists(json_path):
            print(f"Skip {img_path}: annotation {json_path} not found")
            continue

        points, _edge_index, labels, _ = load_data_v2(json_path)

        # Rasterise the labels: 0 is background, label k is stored as k + 1.
        # assumes points holds integer (row, col) pairs -- TODO confirm
        mask_pd = np.zeros(img.shape, np.uint8)
        mask_pd[points[:, 0], points[:, 1]] = labels + 1

        plt.figure(figsize=(9, 9))
        try:
            plt.imshow(img, cmap='gray')
            for class_idx, color in enumerate(palette):
                rows, cols = np.where(mask_pd == class_idx + 1)
                plt.scatter(cols, rows, s=psize, c=color)
            plt.axis('off')
            plt.tight_layout()

            output_path = os.path.join(output_folder, img_filename)
            print(f"Save plot to {output_path}")
            plt.savefig(output_path, dpi=300)
        finally:
            # Close the figure so batch runs do not accumulate open figures.
            plt.close()
2024-07-26 16:14:40 +08:00
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import json
# Assumes a function for loading the JSON data is available.
def plot_multiple_annotations(img_folder, json_folders, psize=26, output_folder='output'):
    """Overlay the annotations of several models on each ``.jpg`` in ``img_folder``.

    For every image ``name.jpg`` and every directory in ``json_folders`` the
    file ``name.json`` is loaded with ``load_data_v2`` (when it exists) and its
    points with labels 0-3 are scattered semi-transparently. All points of
    one model share one colour; colours repeat when there are more models
    than palette entries. The figure is saved under the image's file name in
    ``output_folder``.

    Args:
        img_folder: Directory scanned for ``.jpg`` files.
        json_folders: Iterable of directories, one per model.
        psize: Marker size passed to ``plt.scatter``.
        output_folder: Destination directory; created when missing.

    Images that cannot be read are reported on stdout and skipped.
    """
    model_colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#FF00FF', '#00FFFF']
    os.makedirs(output_folder, exist_ok=True)

    for img_filename in os.listdir(img_folder):
        if not img_filename.endswith(".jpg"):
            continue

        img_path = os.path.join(img_folder, img_filename)
        img = cv2.imread(img_path, 0)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(f"Skip {img_path}: image could not be read")
            continue

        # The JSON name depends only on the image, so compute it once.
        json_filename = os.path.splitext(img_filename)[0] + '.json'

        plt.figure(figsize=(9, 9))
        try:
            plt.imshow(img, cmap='gray')
            for model_idx, json_folder in enumerate(json_folders):
                print(f"Processing {json_filename}")
                json_path = os.path.join(json_folder, json_filename)
                if not os.path.exists(json_path):
                    # A model without a prediction for this image is left out.
                    continue

                points, _edge_index, labels, _ = load_data_v2(json_path)
                # 0 is background, label k is stored as k + 1.
                mask_pd = np.zeros(img.shape, np.uint8)
                mask_pd[points[:, 0], points[:, 1]] = labels + 1

                # Wrap around instead of raising IndexError for a 7th model.
                color = model_colors[model_idx % len(model_colors)]
                for class_idx in range(4):
                    rows, cols = np.where(mask_pd == class_idx + 1)
                    plt.scatter(cols, rows, s=psize, c=color,
                                label=f'Model {model_idx+1}, Class {class_idx+1}', alpha=0.5)
            plt.axis('off')
            plt.tight_layout()

            output_path = os.path.join(output_folder, img_filename)
            print(f"Save plot to {output_path}")
            plt.savefig(output_path, dpi=300)
        finally:
            plt.close()
2024-08-02 17:10:06 +08:00
def _scatter_points_by_label(points, labels, shape, palette, psize):
    """Scatter ``points`` with one colour per label value ``0..len(palette)-1``."""
    # 0 is background, label k is stored as k + 1.
    mask_pd = np.zeros(shape, np.uint8)
    mask_pd[points[:, 0], points[:, 1]] = labels + 1
    for class_idx, color in enumerate(palette):
        rows, cols = np.where(mask_pd == class_idx + 1)
        plt.scatter(cols, rows, s=psize, c=color)


def plot_arange_combine(img_folder, json_folders: Dict[str, str], select_labels: List[ModelLabel], output_folder: str, psize=40) -> None:
    """Combine the predictions of several models into one overlay per image.

    For every ``.jpg`` in ``img_folder`` the selections are processed in order:

    * The first selection is the base layer: all of its points are drawn,
      and the indices of the points carrying one of its selected labels are
      remembered.
    * For each later selection, only points carrying one of that selection's
      labels are drawn, after removing every remembered base index.

    The exclusion compares point *indices* across JSON files.
    NOTE(review): this presumes every model's JSON lists the points in the
    same order -- confirm against the producers of these files.

    Each figure is saved under the image's file name in ``output_folder``.

    Args:
        img_folder: Directory scanned for ``.jpg`` files.
        json_folders: Maps ``ModelLabel.name`` to the directory holding that
            model's ``<image stem>.json`` files.
        select_labels: Selections to draw; ``name`` picks the JSON directory
            and ``label`` lists the label values of interest.
        output_folder: Destination directory; created when missing.
        psize: Marker size passed to ``plt.scatter``.

    Raises:
        KeyError: If a selection's ``name`` is missing from ``json_folders``.
    """
    os.makedirs(output_folder, exist_ok=True)
    palettes = [
        ['#FF5733', '#33FF57', '#3357FF'],  # Red, Green, Blue
        ['#FF33FF', '#33FFFF', '#FFFF33'],  # Magenta, Cyan, Yellow
        ['#F08080', '#90EE90', '#ADD8E6'],  # Light Coral, Light Green, Light Blue
        ['#FFD700', '#8A2BE2', '#FF4500'],  # Gold, Blue Violet, Orange Red
    ]

    for img_filename in os.listdir(img_folder):
        if not img_filename.endswith(".jpg"):
            # Only open a figure for files that are actually rendered.
            continue

        img_path = os.path.join(img_folder, img_filename)
        img = cv2.imread(img_path, 0)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(f"Skip {img_path}: image could not be read")
            continue

        json_filename = os.path.splitext(os.path.basename(img_filename))[0] + '.json'

        plt.figure(figsize=(9, 9))
        try:
            base_points_indexs = []
            for i, select_label in enumerate(select_labels):
                json_path = os.path.join(json_folders[select_label.name], json_filename)
                points, _edge_index, labels, _ = load_data_v2(json_path)
                # Wrap around instead of raising IndexError for a 5th selection.
                palette = palettes[i % len(palettes)]

                if i == 0:
                    # Base layer: remember the selected points, draw everything.
                    base_points_indexs = [
                        np.where(labels == wanted)[0] for wanted in select_label.label
                    ]
                    _scatter_points_by_label(points, labels, img.shape, palette, psize)
                    continue

                for wanted in select_label.label:
                    keep = np.where(labels == wanted)[0]
                    # Complement: drop every index already claimed by the base layer.
                    for base_points_index in base_points_indexs:
                        keep = keep[~np.isin(keep, base_points_index)]
                    _scatter_points_by_label(points[keep], labels[keep], img.shape, palette, psize)

            plt.imshow(img, cmap='gray')
            plt.axis('off')
            plt.tight_layout()

            # Save one file per image instead of overwriting a single 'kw.jpg'.
            output_path = os.path.join(output_folder, img_filename)
            print(f"Save plot to {output_path}")
            plt.savefig(output_path, dpi=300)
        finally:
            plt.close()
# Requirement: render the image that corresponds to the labels selected by the user.
2024-07-26 16:14:40 +08:00
def plot_combine(img_path, json_path, json_path2, json_path3, json_path4, psize=40):
    """Overlay three prediction files on one image, save to ``kw.jpg`` and show it.

    Layers, in drawing order:

    1. ``json_path``: every point, coloured by label. The indices of its
       label-1 points are remembered and excluded from all later layers.
    2. ``json_path2``: remaining label-1 points, then remaining label-0 points.
    3. ``json_path3``: remaining label-1 points.

    Args:
        img_path: Image file, read as greyscale with ``cv2.imread``.
        json_path, json_path2, json_path3: Files read with ``load_data_v2``.
        json_path4: Accepted but not used.
        psize: Marker size passed to ``plt.scatter``.
    """

    def rasterise(pts, lbls):
        # 0 is background, label k is stored as k + 1.
        canvas = np.zeros(img.shape, np.uint8)
        canvas[pts[:, 0], pts[:, 1]] = lbls + 1
        return canvas

    def scatter_canvas(canvas, palette):
        for value, color in enumerate(palette, start=1):
            rows, cols = np.where(canvas == value)
            plt.scatter(cols, rows, s=psize, c=color)

    def without_selected(candidates):
        return candidates[np.isin(candidates, select_point_index, invert=True)]

    # Layer 1: labels taken from the first prediction file.
    # '#814d81' is purple.
    first_palette = ['#dfc8a0', '#814d81', '#debd97']
    points, edge_index, labels, _ = load_data_v2(json_path)
    img = cv2.imread(img_path, 0)
    select_point_index = np.where(labels == 1)[0]
    first_canvas = rasterise(points, labels)
    plt.figure(figsize=(9, 9))
    scatter_canvas(first_canvas, first_palette)

    # Layer 2: label-1 points first, label-0 points second.
    second_palette = ['#debd97', '#84aad0', '#debd97']
    points, edge_index, labels, _ = load_data_v2(json_path2)
    zero_idx = without_selected(np.where(labels == 0)[0])
    one_idx = without_selected(np.where(labels == 1)[0])
    scatter_canvas(rasterise(points[one_idx], labels[one_idx]), second_palette)
    scatter_canvas(rasterise(points[zero_idx], labels[zero_idx]), second_palette)

    # Layer 3: label-1 points of the third file.
    third_palette = ['#84aad0', '#ff5b5b', '#ff5b5b']
    points, edge_index, labels, _ = load_data_v2(json_path3)
    one_idx = without_selected(np.where(labels == 1)[0])
    scatter_canvas(rasterise(points[one_idx], labels[one_idx]), third_palette)

    plt.imshow(img, cmap='gray')
    plt.axis('off')
    plt.tight_layout()
    plt.savefig('kw.jpg', dpi=300)
    plt.show()
def plot_combine1(img_path, json_paths, psize=40):
    """Overlay several prediction files on one image, save to ``kw.jpg`` and show it.

    The first file is drawn completely (one colour per label 0-2) and the
    indices of its label-1 points are remembered. From every later file only
    the label-1 points whose index was not remembered are drawn.

    Args:
        img_path: Image file, opened with PIL.
        json_paths: Sequence of files read with ``load_data_v2``; the first
            one is the base layer.
        psize: Marker size passed to ``plt.scatter``.
    """
    colors = [
        ['#dfc8a0', '#814d81', '#debd97'],
        ['#debd97', '#84aad0', '#debd97'],
        ['#84aad0', '#ff5b5b', '#ff5b5b'],
        ['#84aad0', '#92b964', '#ff5b5b'],
    ]

    plt.figure(figsize=(9, 9))
    img = np.array(Image.open(img_path))
    plt.imshow(img, cmap='gray')

    # Size the label mask from the image (rows, cols) rather than a fixed
    # 512x512, so other resolutions work; [:2] drops a colour axis if present.
    mask_shape = img.shape[:2]
    select_point_index = np.array([], dtype=int)

    for idx, json_path in enumerate(json_paths):
        points, _edge_index, labels, _ = load_data_v2(json_path)

        if idx == 0:
            # Base layer: remember its label-1 points and draw everything.
            select_point_index = np.where(labels == 1)[0]
        else:
            # Later layers: only label-1 points not claimed by the base layer.
            keep = np.where(labels == 1)[0]
            keep = keep[~np.isin(keep, select_point_index)]
            points, labels = points[keep], labels[keep]

        # 0 is background, label k is stored as k + 1.
        mask_pd = np.zeros(mask_shape, np.uint8)
        mask_pd[points[:, 0], points[:, 1]] = labels + 1

        # Wrap around instead of raising IndexError for a 5th file.
        palette = colors[idx % len(colors)]
        for class_idx, color in enumerate(palette):
            rows, cols = np.where(mask_pd == class_idx + 1)
            plt.scatter(cols, rows, s=psize, c=color)

    plt.axis('off')
    plt.tight_layout()
    plt.savefig('kw.jpg', dpi=300)
    plt.show()
if __name__ == '__main__':
    # Example invocation with hard-coded paths.
    try:
        # Directory scanned for .jpg images by plot_arange_combine below.
        folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process'
        # Despite the name this is a single image file; it is only referenced
        # by the commented-out calls further down.
        img_folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.jpg'
        # JSON files for image 49; only referenced by the commented-out
        # plot_multiple_annotations call.
        json_folders = [
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process/49.json',
            ''
        ]
        # Same files in a different order; only referenced by the
        # commented-out plot_combine call.
        new_json_folder = [
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.json',
            ''
        ]
        # Maps each model name (used as ModelLabel.name) to its JSON directory.
        dict_json_folder = {
            "正常掺杂分类": '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process',
            "SMoV三种原子分类": '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process',
            "012三种不同结构的原子类型分类": '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process',
        }
        # Selections in drawing order; plot_arange_combine treats the first
        # entry as the base layer.
        select_labels = [
            ModelLabel(name="012三种不同结构的原子类型分类", label=[1]),
            ModelLabel(name="SMoV三种原子分类", label=[0,1]),
            ModelLabel(name="正常掺杂分类", label=[1]),
        ]
        plot_arange_combine(folder, dict_json_folder, select_labels=select_labels, output_folder='/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/combine_view_0731')
        # Alternative entry points, currently disabled:
        #plot2(folder, folder)
        #plot_combine(img_folder, new_json_folder[0], new_json_folder[1], new_json_folder[2], new_json_folder[3])
        #plot_multiple_annotations(img_folder, json_folders, output_folder='/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/combine_view')
    except Exception as e:
        # Top-level boundary: report the failure with its traceback.
        import traceback
        print(f"Error processing error: {e}")
        traceback.print_exc()
2024-07-22 08:48:33 +08:00
# Folder paths for the images and the JSON files
# img_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv'
# json_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv'
2024-07-23 13:42:42 +08:00
# img_folder = '/home/gao/mouclear/cc/final_todo/SV/other_result/after_test_v2_dengbian'
# json_folder = '/home/gao/mouclear/cc/final_todo/SV/other_result/after_test_v2_dengbian'
#
# # 调用函数
# plot2(img_folder, json_folder)
2024-07-22 08:48:33 +08:00
2024-07-23 13:42:42 +08:00
# post_process_save_path = './predict_result/post_process/'
# plot_save_path = './predict_result/predict_result_view/'
# os.makedirs(plot_save_path, exist_ok=True)
# plot2(img_folder=post_process_save_path, json_folder= post_process_save_path, output_folder=plot_save_path)