atom-predict/egnn_v2/plot_view.py

413 lines
16 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import cv2
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from egnn_core.data import load_data
from egnn_core.data import load_data_v2
from model_type_dict import ModelLabel
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
from typing import Dict, List
# Must support image sizes other than 2048 (not only 2048x2048 inputs):
# take the size from the img.shape parameter.
def plot1(json_path, output_path, psize=24, img=None):
    """Overlay the atom graph from *json_path* on *img* and save it as a JPEG.

    Every point is drawn as a light-blue dot and every graph edge as a grey
    line on top of the grayscale image. The figure is written to
    ``<output_path>/<json basename>.jpg``.

    Args:
        json_path: Annotation file, read with ``load_data``.
        output_path: Existing directory that receives the rendered ``.jpg``.
        psize: Accepted for backward compatibility but currently unused; the
            marker size is fixed at 16.
        img: Image array to draw on. Despite the ``None`` default it is
            required.

    Raises:
        ValueError: If *img* is ``None``.
    """
    if img is None:
        # Fail with a clear message instead of an AttributeError further down.
        raise ValueError("plot1 requires an image array via the 'img' argument")
    points, edge_index, labels, _ = load_data(json_path)

    plt.figure(figsize=(9, 9))
    try:
        plt.imshow(img, cmap='gray')
        # points hold (row, col) pairs, so column 1 is x and column 0 is y.
        plt.scatter(points[:, 1], points[:, 0], s=16, c='#8EB9D9', zorder=2)
        # Each column of edge_index is a (start, end) pair of point indices.
        for start, end in edge_index.T:
            s = points[start]
            e = points[end]
            plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
        plt.axis('off')
        plt.tight_layout()
        json_filename = os.path.splitext(os.path.basename(json_path))[0]
        output_file = os.path.join(output_path, f"{json_filename}.jpg")
        plt.savefig(output_file)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close()
def plot2(img_folder, json_folder, psize=26, output_folder='output'):
    """Draw class-coloured atom markers for every ``.jpg`` in *img_folder*.

    For each image ``<name>.jpg``, the annotation ``<json_folder>/<name>.json``
    is loaded with ``load_data_v2``. Every point is then drawn in the colour of
    its class (labels 0-3; other labels are not drawn). The plot is saved under
    the same file name in *output_folder* at 300 dpi.

    Args:
        img_folder: Directory scanned for ``.jpg`` images.
        json_folder: Directory holding one ``.json`` per image.
        psize: Marker size passed to ``plt.scatter``.
        output_folder: Destination directory; created if missing.

    Raises:
        OSError: If OpenCV cannot read an image.
    """
    c = ['#9BB6CF', '#76F1A2', '#EDC08C', 'red']
    # exist_ok=True already covers the "directory exists" case.
    os.makedirs(output_folder, exist_ok=True)
    for img_filename in sorted(os.listdir(img_folder)):
        if not img_filename.endswith(".jpg"):
            continue
        img_path = os.path.join(img_folder, img_filename)
        # Swap only the final extension; str.replace would also rewrite any
        # '.jpg' occurring earlier in the file name.
        json_filename = os.path.splitext(img_filename)[0] + '.json'
        json_path = os.path.join(json_folder, json_filename)
        img = cv2.imread(img_path, 0)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f"Could not read image: {img_path}")
        points, edge_index, labels, _ = load_data_v2(json_path)
        # Rasterise label + 1 into a mask so that 0 means "no point here".
        mask_pd = np.zeros(img.shape, np.uint8)
        mask_pd[points[:, 0], points[:, 1]] = labels + 1
        plt.figure(figsize=(9, 9))
        try:
            plt.imshow(img, cmap='gray')
            for i, color in enumerate(c):
                h, w = np.where(mask_pd == i + 1)
                plt.scatter(w, h, s=psize, c=color)
            plt.axis('off')
            plt.tight_layout()
            output_path = os.path.join(output_folder, img_filename)
            print(f"Save plot to {output_path}")
            plt.savefig(output_path, dpi=300)
        finally:
            # Always release the figure to avoid leaking memory across images.
            plt.close()
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import json
# Assumes a function for loading the JSON annotation data is available (load_data_v2).
def plot_multiple_annotations(img_folder, json_folders, psize=26, output_folder='output'):
    """Overlay several models' point predictions on each ``.jpg`` in *img_folder*.

    Each entry of *json_folders* is treated as one model. For an image
    ``<name>.jpg``, that model's ``<folder>/<name>.json`` is loaded if it
    exists, and skipped otherwise. All of its points labelled 0-3 are then
    drawn in the model's colour at 50% opacity. The plot is saved under the
    image's file name in *output_folder* at 300 dpi.

    Args:
        img_folder: Directory scanned for ``.jpg`` images.
        json_folders: Sequence of annotation directories, one per model.
        psize: Marker size passed to ``plt.scatter``.
        output_folder: Destination directory; created if missing.

    Raises:
        OSError: If OpenCV cannot read an image.
    """
    model_colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#FF00FF', '#00FFFF']
    os.makedirs(output_folder, exist_ok=True)
    for img_filename in sorted(os.listdir(img_folder)):
        if not img_filename.endswith(".jpg"):
            continue
        img_path = os.path.join(img_folder, img_filename)
        img = cv2.imread(img_path, 0)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f"Could not read image: {img_path}")
        # The annotation name depends only on the image, not on the model.
        json_filename = os.path.splitext(img_filename)[0] + '.json'
        plt.figure(figsize=(9, 9))
        try:
            plt.imshow(img, cmap='gray')
            for model_idx, json_folder in enumerate(json_folders):
                print(f"Processing {json_filename}")
                json_path = os.path.join(json_folder, json_filename)
                if not os.path.exists(json_path):
                    continue
                points, edge_index, labels, _ = load_data_v2(json_path)
                mask_pd = np.zeros(img.shape, np.uint8)
                mask_pd[points[:, 0], points[:, 1]] = labels + 1
                # Cycle the palette so more than six models no longer raise IndexError.
                color = model_colors[model_idx % len(model_colors)]
                for i in range(4):
                    h, w = np.where(mask_pd == i + 1)
                    # NOTE: no legend is drawn, so `label` is currently informational only.
                    plt.scatter(w, h, s=psize, c=color, label=f'Model {model_idx+1}, Class {i+1}', alpha=0.5)
            plt.axis('off')
            plt.tight_layout()
            output_path = os.path.join(output_folder, img_filename)
            print(f"Save plot to {output_path}")
            plt.savefig(output_path, dpi=300)
        finally:
            plt.close()
def _arange_scatter_classes(points, labels, shape, colors, psize):
    """Scatter *points* coloured by class: label q (0, 1 or 2) uses colors[q].

    Points are rasterised into a *shape*-sized mask holding ``label + 1``
    (0 means "no point"). Labels outside 0-2 are therefore not drawn, and if
    two points fall on the same pixel only one of them survives.
    """
    mask_pd = np.zeros(shape, np.uint8)
    # points are (row, col); np.where below returns (rows, cols) = (y, x).
    mask_pd[points[:, 0], points[:, 1]] = labels + 1
    for q in range(3):
        h, w = np.where(mask_pd == q + 1)
        plt.scatter(w, h, s=psize, c=colors[q])


def plot_arange_combine(img_folder, json_folders: Dict[str, str], select_labels: List["ModelLabel"], output_folder: str, psize=40) -> None:
    """Layer several models' predictions on every ``.jpg`` in *img_folder*.

    ``select_labels[0]`` is the base layer. All of its points labelled 0-2 are
    drawn with palette row 0, and the indices of its points whose label is in
    ``select_labels[0].label`` are recorded as "base" points. Every later layer
    ``i`` draws only its points whose label is in ``select_labels[i].label``
    and whose index is not a base point, using palette row ``i``. The plot is
    saved under the image's file name in *output_folder* at 300 dpi.

    Args:
        img_folder: Directory scanned for ``.jpg`` images.
        json_folders: Maps each ``ModelLabel.name`` to the directory holding
            that model's ``<image name>.json`` files.
        select_labels: Layers in drawing order; at most four are supported.
        output_folder: Destination directory; created if missing.
        psize: Marker size passed to ``plt.scatter``.

    Raises:
        ValueError: If more layers are requested than there are palette rows.
        OSError: If OpenCV cannot read an image.

    NOTE(review): base points are excluded by *index*, which assumes every
    model's JSON lists the same atoms in the same order — confirm against the
    post-processing step.
    """
    c = [
        ['#FF5733', '#33FF57', '#3357FF'],  # Red, Green, Blue
        ['#FF33FF', '#33FFFF', '#FFFF33'],  # Magenta, Cyan, Yellow
        ['#F08080', '#90EE90', '#ADD8E6'],  # Light Coral, Light Green, Light Blue
        ['#FFD700', '#8A2BE2', '#FF4500'],  # Gold, Blue Violet, Orange Red
    ]
    if len(select_labels) > len(c):
        raise ValueError(
            f"At most {len(c)} layers are supported, got {len(select_labels)}")
    os.makedirs(output_folder, exist_ok=True)
    for img_filename in sorted(os.listdir(img_folder)):
        if not img_filename.endswith(".jpg"):
            # Skip before creating a figure so non-images do not leak figures.
            continue
        img_path = os.path.join(img_folder, img_filename)
        img = cv2.imread(img_path, 0)
        if img is None:
            raise OSError(f"Could not read image: {img_path}")
        json_filename = os.path.splitext(img_filename)[0] + '.json'
        plt.figure(figsize=(9, 9))
        try:
            base_points_index = np.array([], dtype=np.intp)
            for i, select_label in enumerate(select_labels):
                json_path = os.path.join(json_folders[select_label.name], json_filename)
                points, edge_index, labels, _ = load_data_v2(json_path)
                if i == 0:
                    # Base layer: remember the selected points, draw everything.
                    base_points_index = np.flatnonzero(np.isin(labels, select_label.label))
                    _arange_scatter_classes(points, labels, img.shape, c[i], psize)
                else:
                    # Later layers: selected labels minus the union of base points.
                    keep = np.flatnonzero(np.isin(labels, select_label.label))
                    keep = keep[~np.isin(keep, base_points_index)]
                    _arange_scatter_classes(points[keep], labels[keep], img.shape, c[i], psize)
            plt.imshow(img, cmap='gray')
            plt.axis('off')
            plt.tight_layout()
            # Previously every image overwrote 'kw.jpg' and output_folder went unused.
            output_path = os.path.join(output_folder, img_filename)
            print(f"Save plot to {output_path}")
            plt.savefig(output_path, dpi=300)
        finally:
            plt.close()
# Requirement: draw the corresponding image according to the labels selected by the user.
def plot_combine(img_path, json_path, json_path2, json_path3, json_path4, psize=40, output_path='kw.jpg'):
    """Overlay three models' predictions on one image as successive layers.

    - Layer 1 (*json_path*) draws every point labelled 0-2 with
      ``['#dfc8a0', '#814d81', '#debd97']``. It records the indices of its
      label-1 points as "selected".
    - Layer 2 (*json_path2*) draws its label-1 points, then its label-0 points,
      whose indices were not selected, with ``['#debd97', '#84aad0', '#debd97']``.
    - Layer 3 (*json_path3*) draws its non-selected label-1 points in
      ``'#ff5b5b'``.

    The grayscale image goes underneath. The figure is saved to *output_path*
    at 300 dpi, shown, and then closed.

    Args:
        img_path: Image read in grayscale with OpenCV.
        json_path, json_path2, json_path3: Annotation files for layers 1-3,
            read with ``load_data_v2``.
        json_path4: Accepted for backward compatibility but ignored; the
            fourth layer is disabled.
        psize: Marker size passed to ``plt.scatter``.
        output_path: Where the figure is saved. Defaults to ``'kw.jpg'``, the
            previously hard-coded location.

    Raises:
        OSError: If OpenCV cannot read *img_path*.

    NOTE(review): points are matched across files by index, which assumes all
    JSON files list the same atoms in the same order — confirm.
    """
    def scatter_classes(pts, lbls, shape, colors):
        # Rasterise label + 1 into a mask (0 = empty) and scatter classes 0-2.
        mask_pd = np.zeros(shape, np.uint8)
        mask_pd[pts[:, 0], pts[:, 1]] = lbls + 1
        for i in range(3):
            h, w = np.where(mask_pd == i + 1)
            plt.scatter(w, h, s=psize, c=colors[i])

    def unselected(lbls, value, selected):
        # Indices whose label equals `value` and that are not in `selected`.
        idx = np.where(lbls == value)[0]
        return idx[~np.isin(idx, selected)]

    points, edge_index, labels, _ = load_data_v2(json_path)
    img = cv2.imread(img_path, 0)
    if img is None:
        raise OSError(f"Could not read image: {img_path}")
    select_point_index = np.where(labels == 1)[0]

    plt.figure(figsize=(9, 9))
    try:
        # Layer 1: all classes from the first model.
        scatter_classes(points, labels, img.shape, ['#dfc8a0', '#814d81', '#debd97'])

        # Layer 2: non-selected label-1 points, then non-selected label-0 points.
        layer2_colors = ['#debd97', '#84aad0', '#debd97']
        points, edge_index, labels, _ = load_data_v2(json_path2)
        vv = unselected(labels, 1, select_point_index)
        aa = unselected(labels, 0, select_point_index)
        scatter_classes(points[vv], labels[vv], img.shape, layer2_colors)
        scatter_classes(points[aa], labels[aa], img.shape, layer2_colors)

        # Layer 3: non-selected label-1 points from the third model.
        points, edge_index, labels, _ = load_data_v2(json_path3)
        vv = unselected(labels, 1, select_point_index)
        scatter_classes(points[vv], labels[vv], img.shape, ['#84aad0', '#ff5b5b', '#ff5b5b'])

        plt.imshow(img, cmap='gray')
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(output_path, dpi=300)
        plt.show()
    finally:
        # Release the figure once shown (or if anything above failed).
        plt.close()
def plot_combine1(img_path, json_paths, psize=40, output_path='kw.jpg'):
    """Layer the predictions from *json_paths* over the image at *img_path*.

    The first file draws all its points labelled 0-2 and records the indices of
    its label-1 points as "selected". Every later file draws only its label-1
    points whose indices were not selected. File ``idx`` uses colour row
    ``colors[idx]``, so at most four files are supported. The figure is saved
    to *output_path* at 300 dpi, shown, and then closed.

    Args:
        img_path: Image opened with PIL.
        json_paths: Annotation files read with ``load_data_v2``, in layer order.
        psize: Marker size passed to ``plt.scatter``.
        output_path: Where the figure is saved. Defaults to ``'kw.jpg'``, the
            previously hard-coded location.

    Raises:
        ValueError: If more than four JSON paths are given.

    NOTE(review): unlike plot_combine, label-0 points of later files are not
    drawn here; confirm that is intended.
    """
    colors = [
        ['#dfc8a0', '#814d81', '#debd97'],
        ['#debd97', '#84aad0', '#debd97'],
        ['#84aad0', '#ff5b5b', '#ff5b5b'],
        ['#84aad0', '#92b964', '#ff5b5b'],
    ]
    if len(json_paths) > len(colors):
        raise ValueError(
            f"At most {len(colors)} JSON files are supported, got {len(json_paths)}")
    # Close the file handle deterministically once the pixels are copied out.
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img)
    # Masks follow the image's pixel grid; a fixed 512x512 mask failed or
    # clipped points on any other image size.
    mask_shape = img.shape[:2]

    plt.figure(figsize=(9, 9))
    try:
        plt.imshow(img, cmap='gray')
        select_point_index = np.array([], dtype=np.intp)
        for idx, json_path in enumerate(json_paths):
            points, edge_index, labels, _ = load_data_v2(json_path)
            if idx == 0:
                select_point_index = np.where(labels == 1)[0]
            else:
                # Keep only label-1 points that the first file did not select.
                keep = np.where(labels == 1)[0]
                keep = keep[~np.isin(keep, select_point_index)]
                points = points[keep]
                labels = labels[keep]
            # Zero background (no class), label + 1 at each point.
            mask_pd = np.zeros(mask_shape, np.uint8)
            mask_pd[points[:, 0], points[:, 1]] = labels + 1
            for i in range(3):
                h, w = np.where(mask_pd == i + 1)
                plt.scatter(w, h, s=psize, c=colors[idx][i])
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(output_path, dpi=300)
        plt.show()
    finally:
        plt.close()
if __name__ == '__main__':
    import traceback

    # Example invocation. Building the inputs stays inside the try as well, so
    # any failure is printed with a traceback instead of propagating.
    try:
        # Directory scanned for .jpg images by plot_arange_combine.
        folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process'

        # Inputs for the alternative entry points that are commented out below.
        img_folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.jpg'
        json_folders = [
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process/49.json',
            '',
        ]
        new_json_folder = [
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process/49.json',
            '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.json',
            '',
        ]

        # Model name -> directory holding that model's per-image JSON files.
        dict_json_folder = {
            "正常掺杂分类": '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process',
            "SMoV三种原子分类": '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process',
            "012三种不同结构的原子类型分类": '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process',
        }

        # Layers in drawing order as (model name, selected labels); the first
        # entry is the base layer.
        layer_spec = [
            ("012三种不同结构的原子类型分类", [1]),
            ("SMoV三种原子分类", [0, 1]),
            ("正常掺杂分类", [1]),
        ]
        select_labels = [ModelLabel(name=model_name, label=wanted) for model_name, wanted in layer_spec]

        plot_arange_combine(
            folder,
            dict_json_folder,
            select_labels=select_labels,
            output_folder='/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/combine_view_0731',
        )
        # plot2(folder, folder)
        # plot_combine(img_folder, new_json_folder[0], new_json_folder[1], new_json_folder[2], new_json_folder[3])
        # plot_multiple_annotations(img_folder, json_folders, output_folder='/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/combine_view')
    except Exception as e:
        print(f"Error processing error: {e}")
        traceback.print_exc()

    # Folder paths for the images and JSON files (older examples):
    # img_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv'
    # json_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv'
    # img_folder = '/home/gao/mouclear/cc/final_todo/SV/other_result/after_test_v2_dengbian'
    # json_folder = '/home/gao/mouclear/cc/final_todo/SV/other_result/after_test_v2_dengbian'
    #
    # # Call the function
    # plot2(img_folder, json_folder)
    # post_process_save_path = './predict_result/post_process/'
    # plot_save_path = './predict_result/predict_result_view/'
    # os.makedirs(plot_save_path, exist_ok=True)
    # plot2(img_folder=post_process_save_path, json_folder= post_process_save_path, output_folder=plot_save_path)