diff --git a/egnn_v2/egnn_core/data.py b/egnn_v2/egnn_core/data.py
index 7654a8f..a3e1c63 100644
--- a/egnn_v2/egnn_core/data.py
+++ b/egnn_v2/egnn_core/data.py
@@ -9,7 +9,7 @@ from PIL import Image
 from torch_geometric.data import Data
 from torch_geometric.data import InMemoryDataset, download_url, Dataset
 from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay
-#from metricse2e_vor import norm_line_sv_label
+from model_type_dict import norm_line_sv_label
 import albumentations as A
 
 try:
@@ -27,7 +27,7 @@ def get_training_augmentation():
     return A.Compose(train_transform)
 
 def get_validation_augmentation(model_type=None):
-    if model_type == "点线缺陷分类":
+    if model_type == norm_line_sv_label:
         print("线缺陷检测模型需要归一化")
         test_transform = [
             A.Normalize(),
diff --git a/egnn_v2/egnn_core/model.py b/egnn_v2/egnn_core/model.py
index 71e8448..7abb6d0 100644
--- a/egnn_v2/egnn_core/model.py
+++ b/egnn_v2/egnn_core/model.py
@@ -1,6 +1,6 @@
 import torch.nn as nn
 from egnn_core.egnn_clean import EGNN
-from metricse2e_vor import cz_label
+from model_type_dict import cz_label
 
 class PL_EGNN(nn.Module):
     def __init__(self, model_type=None):
diff --git a/egnn_v2/metricse2e_vor.py b/egnn_v2/metricse2e_vor.py
index 14b46a7..b7576d1 100644
--- a/egnn_v2/metricse2e_vor.py
+++ b/egnn_v2/metricse2e_vor.py
@@ -1,66 +1,15 @@
 import os
 import shutil
-import cv2
 import glob
 import json
 import numpy as np
-import pandas as pd
-import matplotlib.pyplot as plt
 from tqdm import tqdm
-from PIL import Image
-from egnn_utils.e2e_metrics import get_metrics
-from egnn_core.data import get_y_3
 from egnn_core.data import load_data
-from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
-
-
-class_dict = {
-    0: 'Norm',
-    1: 'SV',
-    2: 'LineSV',
-}
-
-class_dict_rev = {
-    'Norm': 1,
-    'SV': 2,
-    'LineSV': 3,
-}
-
-#cz:
-cz_class_dict = {
-    0: 'atom',
-    1: 'cz',
-}
-
-#xj:
-xj_class_dict = {
-    0: '0',
-    1: '1',
-    2: '2',
-}
-
-#smov:
-smov_class_dict = {
-    0: 'S',
-    1: 'Mo',
-    2: 'V'
-    }
-
-norm_line_sv_label = "点线缺陷分类"
-cz_label = "正常掺杂分类"
-xj_label = "012三种不同结构的原子类型分类"
-smov_label = "S/Mo/V三种原子分类"
-
-model_path_dict = {
-    norm_line_sv_label: "model/last.ckpt",
-    cz_label: "model/cz_new.ckpt",
-    xj_label: "model/xj_new_2.ckpt",
-    smov_label: "model/smov.ckpt"
-
-}
+from model_type_dict import class_dict, cz_class_dict, xj_class_dict, smov_class_dict, norm_line_sv_label, \
+    cz_label, xj_label, smov_label
 
 def get_class_dict(model_type):
     if model_type == cz_label:
diff --git a/egnn_v2/model_type_dict.py b/egnn_v2/model_type_dict.py
new file mode 100644
index 0000000..1bedbab
--- /dev/null
+++ b/egnn_v2/model_type_dict.py
@@ -0,0 +1,37 @@
+
+
+class_dict = {
+    0: 'Norm',
+    1: 'SV',
+    2: 'LineSV',
+}
+class_dict_rev = {
+    'Norm': 1,
+    'SV': 2,
+    'LineSV': 3,
+}
+cz_class_dict = {
+    0: 'atom',
+    1: 'cz',
+}
+xj_class_dict = {
+    0: '0',
+    1: '1',
+    2: '2',
+}
+smov_class_dict = {
+    0: 'S',
+    1: 'Mo',
+    2: 'V'
+}
+norm_line_sv_label = "点线缺陷分类"
+cz_label = "正常掺杂分类"
+xj_label = "012三种不同结构的原子类型分类"
+smov_label = "SMoV三种原子分类"
+model_path_dict = {
+    norm_line_sv_label: "model/last.ckpt",
+    cz_label: "model/cz_new.ckpt",
+    xj_label: "model/xj_new_2.ckpt",
+    smov_label: "model/smov.ckpt"
+
+}
diff --git a/egnn_v2/plot_view.py b/egnn_v2/plot_view.py
index 1f86a5d..085af9d 100644
--- a/egnn_v2/plot_view.py
+++ b/egnn_v2/plot_view.py
@@ -88,6 +88,217 @@ def plot2(img_folder, json_folder, psize=26, output_folder='output'):
         plt.savefig(output_path, dpi = 300)
         plt.close()  # 关闭当前的绘图窗口,以避免内存泄漏
 
+import os
+import cv2
+import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image
+import json
+
+# 假设你有一个用于加载JSON数据的函数
+
+def plot_multiple_annotations(img_folder, json_folders, psize=26, output_folder='output'):
+    # 定义颜色数组
+    colors = ['#9BB6CF', '#76F1A2', '#EDC08C', 'red']
+    model_colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#FF00FF', '#00FFFF']
+
+    # 确保输出文件夹存在
+    if not os.path.exists(output_folder):
+        os.makedirs(output_folder, exist_ok=True)
+
+    # 遍历指定文件夹中的所有文件
+    for img_filename in os.listdir(img_folder):
+
+        if img_filename.endswith(".jpg"):  # 确保是jpeg文件
+            img_path = os.path.join(img_folder, img_filename)
+            img = cv2.imread(img_path, 0)
+
+            plt.figure(figsize=(9, 9))
+            plt.imshow(img, cmap='gray')
+
+            for model_idx, json_folder in enumerate(json_folders):
+                json_filename = img_filename.replace('.jpg', '.json')
+                print(f"Processing {json_filename}")
+                json_path = os.path.join(json_folder, json_filename)
+
+                if os.path.exists(json_path):
+                    points, edge_index, labels, _ = load_data_v2(json_path)
+                    mask_pd = np.zeros(img.shape, np.uint8)
+                    mask_pd[points[:, 0], points[:, 1]] = labels + 1
+
+                    for i in range(4):
+                        h, w = np.where(mask_pd == i + 1)
+                        plt.scatter(w, h, s=psize, c=model_colors[model_idx], label=f'Model {model_idx+1}, Class {i+1}', alpha=0.5)
+
+            plt.axis('off')
+            plt.tight_layout()
+
+            # 保存图像
+            output_path = os.path.join(output_folder, img_filename)
+            print(f"Save plot to {output_path}")
+            plt.savefig(output_path, dpi=300)
+            plt.close()
+
+
+def plot_combine(img_path, json_path, json_path2, json_path3, json_path4, psize=40):
+    c = ['#dfc8a0', '#814d81', '#debd97']
+    ##814d81紫色
+    points, edge_index, labels, _ = load_data_v2(json_path)
+    img = cv2.imread(img_path, 0)
+    select_point_index = np.where(labels == 1)[0]
+
+    mask_pd = np.zeros(img.shape, np.uint8)
+    mask_pd[points[:, 0], points[:, 1]] = labels + 1
+    mask_pd = np.array(mask_pd, np.uint8)
+
+    plt.figure(figsize=(9, 9))
+    #img = np.array(Image.open(img_path))
+
+    for i in range(3):
+        h, w = np.where(mask_pd == i + 1)
+        plt.scatter(w, h, s=psize, c=c[i])
+
+    c = ['#debd97', '#84aad0', '#debd97']
+
+    points, edge_index, labels, _ = load_data_v2(json_path2)
+
+    AA = np.where(labels == 0)[0]
+    AA = AA[~np.isin(AA, select_point_index)]
+    VV = np.where(labels == 1)[0]
+    VV = VV[~np.isin(VV, select_point_index)]
+
+    labelsv = labels[VV]
+    pointsv = points[VV]
+    mask_pd = np.zeros(img.shape, np.uint8)
+    mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1
+    mask_pd = np.array(mask_pd, np.uint8)
+    for i in range(3):
+        h, w = np.where(mask_pd == i + 1)
+        plt.scatter(w, h, s=psize, c=c[i])
+
+    # VV = np.where(labels == 2)[0]
+    labels = labels[AA]
+    points = points[AA]
+    mask_pd = np.zeros(img.shape, np.uint8)
+    mask_pd[points[:, 0], points[:, 1]] = labels + 1
+    mask_pd = np.array(mask_pd, np.uint8)
+    for i in range(3):
+        h, w = np.where(mask_pd == i + 1)
+        plt.scatter(w, h, s=psize, c=c[i])
+
+    c = ['#84aad0', '#ff5b5b', '#ff5b5b']
+    # points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
+    points, edge_index, labels, _ = load_data_v2(json_path3)
+
+    VV = np.where(labels == 1)[0]
+    VV = VV[~np.isin(VV, select_point_index)]
+    labels = labels[VV]
+    points = points[VV]
+
+    mask_pd = np.zeros(img.shape, np.uint8)
+    mask_pd[points[:, 0], points[:, 1]] = labels + 1
+    mask_pd = np.array(mask_pd, np.uint8)
+    for i in range(3):
+        h, w = np.where(mask_pd == i + 1)
+        plt.scatter(w, h, s=psize, c=c[i])
+
+    # c = ['#84aad0', '#92b964', '#ff5b5b']
+    # points, edge_index, labels, _ = load_data_v2(json_path4)
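The plotting helpers added above all repeat one pattern: write `labels + 1` into a uint8 mask at the predicted point coordinates, then recover each class with `np.where` and draw it as a coloured scatter layer (rows on y, columns on x). A minimal, self-contained sketch of that pattern, with synthetic points standing in for the `load_data_v2` output:

```python
import numpy as np
import matplotlib.pyplot as plt

def scatter_by_label(points, labels, img_shape, colors, psize=40):
    """Mask-and-scatter pattern used by plot_multiple_annotations/plot_combine:
    one scatter call per class, columns plotted as x and rows as y."""
    mask = np.zeros(img_shape, np.uint8)
    mask[points[:, 0], points[:, 1]] = labels + 1
    for i, color in enumerate(colors):
        h, w = np.where(mask == i + 1)
        plt.scatter(w, h, s=psize, c=color)

# Synthetic stand-in for the (points, labels) that load_data_v2 would return.
rng = np.random.default_rng(0)
pts = rng.integers(0, 512, size=(200, 2))
lbl = rng.integers(0, 3, size=200)

plt.figure(figsize=(9, 9))
scatter_by_label(pts, lbl, (512, 512), ['#dfc8a0', '#84aad0', '#93ba66'])
plt.axis('off')
plt.tight_layout()
plt.show()
```

The mask round-trip mainly deduplicates coincident points; plain boolean indexing (`plt.scatter(points[labels == i, 1], points[labels == i, 0], ...)`) would draw the same layers without allocating a full-size image.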
+ # + # VV = np.where(labels == 1)[0] + # VV = VV[~np.isin(VV, select_point_index)] + # + # labels = labels[VV] + # points = points[VV] + # mask_pd = np.zeros((512, 512)) + 256 + # mask_pd[points[:, 0], points[:, 1]] = labels + 1 + # mask_pd = np.array(mask_pd, np.uint8) + # for i in range(3): + # h, w = np.where(mask_pd == i + 1) + # plt.scatter(w, h, s=psize, c=c[i]) + + plt.imshow(img, cmap='gray') + plt.axis('off') + plt.tight_layout() + plt.savefig('kw.jpg', dpi=300) + plt.show() + + +def plot_combine1(img_path, json_paths, psize=40): + colors = [ + ['#dfc8a0', '#814d81', '#debd97'], + ['#debd97', '#84aad0', '#debd97'], + ['#84aad0', '#ff5b5b', '#ff5b5b'], + ['#84aad0', '#92b964', '#ff5b5b'] + ] + + select_point_index = [] + + # 初始化绘图 + plt.figure(figsize=(9, 9)) + img = np.array(Image.open(img_path)) + plt.imshow(img, cmap='gray') + + for idx, json_path in enumerate(json_paths): + points, edge_index, labels, _ = load_data_v2(json_path) + + # 仅对于第一个json文件,记录选择的点索引 + if idx == 0: + select_point_index = np.where(labels == 1)[0] + + if idx == 0: + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + else: + AA = np.where(labels == 0)[0] + AA = AA[~np.isin(AA, select_point_index)] + VV = np.where(labels == 1)[0] + VV = VV[~np.isin(VV, select_point_index)] + + if idx in [1, 2]: + labels_v = labels[VV] + points_v = points[VV] + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points_v[:, 0], points_v[:, 1]] = labels_v + 1 + mask_pd = np.array(mask_pd, np.uint8) + else: + labels = labels[VV] + points = points[VV] + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=colors[idx][i]) + + plt.axis('off') + plt.tight_layout() + plt.savefig('kw.jpg', dpi=300) + plt.show() + + + +if __name__ == '__main__': +# 示例调用 + try: + folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process' + img_folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.jpg' + json_folders = [ + '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.json', + '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process/49.json', + '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process/49.json', + '' + ] + #plot2(folder, folder) + plot_combine(img_folder, json_folders[0], json_folders[1], json_folders[2], json_folders[3]) + #plot_multiple_annotations(img_folder, json_folders, output_folder='/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/combine_view') + except Exception as e: + import traceback + print(f"Error processing error: {e}") + traceback.print_exc() + # 指定图像和JSON文件的文件夹路径 # img_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv' # json_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv' diff --git a/egnn_v2/plot_vor_pingjie4_auuuug(1).py b/egnn_v2/plot_vor_pingjie4_auuuug(1).py new file mode 100644 index 0000000..506b905 --- /dev/null +++ b/egnn_v2/plot_vor_pingjie4_auuuug(1).py @@ -0,0 +1,199 @@ +import os +import cv2 +import glob +import json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +from tqdm 
import tqdm +from PIL import Image +from core.data import load_data_v2_plot,load_data_v2_plot_jj,load_data_v2_plot_cz,load_data_v2_plot_kw +from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix + + +def plot1(json_path, psize=24, img_size=2048): + bg = np.zeros((img_size, img_size)) + 255 + bg[0, 0] = 0 + + points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28) + + mask_pd = np.zeros((img_size, img_size)) + mask_pd[points[:, 0], points[:, 1]] = 1 + mask_pd = np.array(mask_pd, np.uint8) + + plt.figure(figsize=(9, 9)) + plt.imshow(bg, cmap='gray') + h, w = np.where(mask_pd != 0) + plt.scatter(points[:, 1], points[:, 0], s=16, c='#8EB9D9', zorder=2) + # + for idx, (s, e) in enumerate(edge_index.T): + s = points[s] + e = points[e] + plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1) + + plt.axis('off') + + plt.tight_layout() + + +def plot2(img_path, json_path, json_path2, json_path3, json_path4, psize=40): + # c = ['red', 'green', 'yellow'] + # c = [ '#dfc8a0', '#84aad0','#93ba66'] + + # c = [ '#dfc8a0','#debd97', '#84aad0'] + + # c = ['#9BB6CF', '#9BB6CF', '#9BB6CF'] + # c = ['#83a2c1', '#9BB6CF', '#9BB6CF'] + + # points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=25) + c = ['#dfc8a0', '#84aad0', '#93ba66'] + points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28) + + # points, edge_index, labels, _ = load_data_v2_plot_jj(json_path,max_edge_length=31) + + # mask_pd = np.zeros((2048, 2048)) + + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + + plt.figure(figsize=(9, 9)) + + img = np.array(Image.open(img_path)) + # bg = np.zeros_like(img) + 255 + # bg[0, 0] = 0 + # plt.imshow(bg, cmap='gray') + + + + # for idx, (s, e) in enumerate(edge_index.T): + # s = points[s] + # e = points[e] + # plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1) + + # for i in range(3): + # h, w = np.where(mask_pd == i + 1) + # plt.scatter(w, h, s=psize, c=c[i]) + # + # plt.imshow(img, cmap='gray') + # plt.axis('off') + # plt.tight_layout() + # plt.show() + + # xj_points = points + + # c = ['#84aad0', '#92b964', '#84aad0'] + + + # # + # for idx, (s, e) in enumerate(edge_index.T): + # s = points[s] + # e = points[e] + # plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1) + + + + + # c = [ '#84aad0','#92b964','#92b964'] + #debd97 黄 92b964 绿 84aad0 蓝 + # c = [ '#92b964','#debd97','#debd97'] + c = [ '#debd97','#84aad0','#debd97'] + points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28) + + AA = np.where(labels == 0)[0] + + VV = np.where(labels == 1)[0] + labelsv = labels[VV] + pointsv = points[VV] + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + # VV = np.where(labels == 2)[0] + labels = labels[AA] + points = points[AA] + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + # plt.imshow(img, cmap='gray') + # plt.axis('off') + # plt.tight_layout() + # plt.show() + c = ['#84aad0', '#92b964', '#ff5b5b'] + points, edge_index, labels, _ = load_data_v2_plot_kw(json_path4, 
max_edge_length=28) + + VV = np.where(labels == 1)[0] + labels = labels[VV] + points = points[VV] + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + + c = ['#84aad0', '#ff5b5b', '#ff5b5b'] + # points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28) + points, edge_index, labels, _ = load_data_v2_plot_cz(json_path3, max_edge_length=28) + + # a = [] + # b = [] + # + # for i in points: + # if i not in xj_points: + # print(i) + VV = np.where(labels == 1)[0] + labels = labels[VV] + points = points[VV] + + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + # plt.imshow(img, cmap='gray') + # plt.axis('off') + # plt.tight_layout() + + + + # + # + + plt.imshow(img, cmap='gray') + plt.axis('off') + plt.tight_layout() + + plt.savefig('49.jpg', dpi=300) + plt.show() + + +# json_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.json' +# img_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.jpg' +# json_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.json' +# img_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.jpg' +# json_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.json' +# img_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.jpg' +# json_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.json' +# img_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.jpg' + +json_path1 = '/home/gao/mouclear/cc/final_todo/fig_50/xj_result/raw/49.json' +img_path1 = '/home/gao/mouclear/cc/final_todo/fig_50/xj_result/raw/49.jpg' +json_path2 = '/home/gao/mouclear/cc/final_todo/fig_50/jj_result/raw/49.json' +json_path3 = '/home/gao/mouclear/cc/final_todo/fig_50/cz_result/raw/49.json' +json_path4 = '/home/gao/mouclear/cc/final_todo/kw_49_v2/49.json' + + + +plot2(img_path1,json_path1, json_path2, json_path3, json_path4) \ No newline at end of file diff --git a/egnn_v2/plot_vor_pingjie4_auuuug_53.py b/egnn_v2/plot_vor_pingjie4_auuuug_53.py new file mode 100644 index 0000000..7481fb1 --- /dev/null +++ b/egnn_v2/plot_vor_pingjie4_auuuug_53.py @@ -0,0 +1,325 @@ +import os +import cv2 +import glob +import json +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt + +from tqdm import tqdm +from PIL import Image +from core.data import load_data_v2_plot,load_data_v2_plot_jj,load_data_v2_plot_cz,load_data_v2_plot_kw +from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix + + +def plot1(json_path, psize=24, img_size=2048): + bg = np.zeros((img_size, img_size)) + 255 + bg[0, 0] = 0 + + points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28) + + mask_pd = np.zeros((img_size, img_size)) + mask_pd[points[:, 0], points[:, 1]] = 1 + mask_pd = np.array(mask_pd, np.uint8) + + plt.figure(figsize=(9, 9)) + plt.imshow(bg, cmap='gray') + h, w = np.where(mask_pd != 0) + plt.scatter(points[:, 1], points[:, 0], s=16, c='#8EB9D9', zorder=2) + # + for idx, (s, e) in enumerate(edge_index.T): + s = points[s] + e = points[e] + plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1) + + plt.axis('off') + + plt.tight_layout() + +def plot_combine(img_path, json_path, json_path2, json_path3, json_path4, 
psize=40): + c = ['#dfc8a0', '#814d81', '#debd97'] + ##814d81紫色 + points, edge_index, labels, _ = load_data_v2(json_path) + + select_point_index = np.where(labels == 1)[0] + + + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + + plt.figure(figsize=(9, 9)) + img = np.array(Image.open(img_path)) + + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + c = [ '#debd97','#84aad0','#debd97'] + + points, edge_index, labels, _ = load_data_v2(json_path2) + + AA = np.where(labels == 0)[0] + AA = AA[~np.isin(AA, select_point_index)] + VV = np.where(labels == 1)[0] + VV = VV[~np.isin(VV, select_point_index)] + + + labelsv = labels[VV] + pointsv = points[VV] + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + # VV = np.where(labels == 2)[0] + labels = labels[AA] + points = points[AA] + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + + c = ['#84aad0', '#ff5b5b', '#ff5b5b'] + # points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28) + points, edge_index, labels, _ = load_data_v2(json_path3) + + VV = np.where(labels == 1)[0] + VV = VV[~np.isin(VV, select_point_index)] + labels = labels[VV] + points = points[VV] + + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + c = ['#84aad0', '#92b964', '#ff5b5b'] + points, edge_index, labels, _ = load_data_v2(json_path4) + + VV = np.where(labels == 1)[0] + VV = VV[~np.isin(VV, select_point_index)] + + labels = labels[VV] + points = points[VV] + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + plt.imshow(img, cmap='gray') + plt.axis('off') + plt.tight_layout() + plt.savefig('kw.jpg', dpi=300) + plt.show() + + + +def plot2(img_path, json_path, json_path2, json_path3, json_path4, psize=40): + # c = ['red', 'green', 'yellow'] + # c = [ '#dfc8a0', '#84aad0','#93ba66'] + + # c = [ '#dfc8a0','#debd97', '#84aad0'] + + # c = ['#9BB6CF', '#9BB6CF', '#9BB6CF'] + # c = ['#83a2c1', '#9BB6CF', '#9BB6CF'] + + # points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=25) + # c = ['#dfc8a0', '#84aad0', '#93ba66'] + # c = ['#dfc8a0', '#814d81', '#debd97'] + c = ['#dfc8a0', '#814d81', '#debd97'] + ##814d81紫色 + points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28) + + select_point_index = np.where(labels == 1)[0] + + # points, edge_index, labels, _ = load_data_v2_plot_jj(json_path,max_edge_length=31) + + # mask_pd = np.zeros((2048, 2048)) + + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + + plt.figure(figsize=(9, 9)) + img = np.array(Image.open(img_path)) + # bg = np.zeros_like(img) + 255 + # bg[0, 0] = 0 + # plt.imshow(bg, cmap='gray') + + # for idx, (s, e) in enumerate(edge_index.T): + # s = points[s] + # e = 
points[e] + # plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1) + + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + # plt.imshow(img, cmap='gray') + # plt.axis('off') + # plt.tight_layout() + # plt.show() + + + + # # + # for idx, (s, e) in enumerate(edge_index.T): + # s = points[s] + # e = points[e] + # plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1) + + # c = [ '#84aad0','#92b964','#92b964'] + #debd97 黄 92b964 绿 84aad0 蓝 + # c = [ '#92b964','#debd97','#debd97'] + c = [ '#debd97','#84aad0','#debd97'] + + points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28) + + AA = np.where(labels == 0)[0] + AA = AA[~np.isin(AA, select_point_index)] + VV = np.where(labels == 1)[0] + VV = VV[~np.isin(VV, select_point_index)] + + + labelsv = labels[VV] + pointsv = points[VV] + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + # VV = np.where(labels == 2)[0] + labels = labels[AA] + points = points[AA] + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + + c = ['#84aad0', '#ff5b5b', '#ff5b5b'] + # points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28) + points, edge_index, labels, _ = load_data_v2_plot_cz(json_path3, max_edge_length=28) + + VV = np.where(labels == 1)[0] + VV = VV[~np.isin(VV, select_point_index)] + labels = labels[VV] + points = points[VV] + + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + # plt.imshow(img, cmap='gray') + # plt.axis('off') + # plt.tight_layout() + # + # plt.savefig('cz.jpg', dpi=300) + # plt.show() + + # plt.imshow(img, cmap='gray') + # plt.axis('off') + # plt.tight_layout() + # + # plt.savefig('JJ.jpg', dpi=300) + # plt.show() + + # plt.imshow(img, cmap='gray') + # plt.axis('off') + # plt.tight_layout() + # plt.show() + c = ['#84aad0', '#92b964', '#ff5b5b'] + points, edge_index, labels, _ = load_data_v2_plot_kw(json_path4, max_edge_length=28) + + VV = np.where(labels == 1)[0] + VV = VV[~np.isin(VV, select_point_index)] + + labels = labels[VV] + points = points[VV] + mask_pd = np.zeros((512, 512)) + 256 + mask_pd[points[:, 0], points[:, 1]] = labels + 1 + mask_pd = np.array(mask_pd, np.uint8) + for i in range(3): + h, w = np.where(mask_pd == i + 1) + plt.scatter(w, h, s=psize, c=c[i]) + + plt.imshow(img, cmap='gray') + plt.axis('off') + plt.tight_layout() + plt.savefig('kw.jpg', dpi=300) + plt.show() + # + # + # c = ['#84aad0', '#ff5b5b', '#ff5b5b'] + # # points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28) + # points, edge_index, labels, _ = load_data_v2_plot_cz(json_path3, max_edge_length=28) + # + # # a = [] + # # b = [] + # # + # # for i in points: + # # if i not in xj_points: + # # print(i) + # VV = np.where(labels == 1)[0] + # labels = labels[VV] + # points = points[VV] + # + # mask_pd = np.zeros((512, 512)) + 256 + # mask_pd[points[:, 0], points[:, 1]] = labels + 1 + # mask_pd = np.array(mask_pd, np.uint8) + # for i in range(3): + # h, w = np.where(mask_pd == i 
+ 1) + # plt.scatter(w, h, s=psize, c=c[i]) + + # plt.imshow(img, cmap='gray') + # plt.axis('off') + # plt.tight_layout() + + + + # + # + # + # plt.imshow(img, cmap='gray') + # plt.axis('off') + # plt.tight_layout() + # + # plt.savefig('49.jpg', dpi=300) + # plt.show() + + +# json_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.json' +# img_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.jpg' +# json_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.json' +# img_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.jpg' +# json_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.json' +# img_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.jpg' +# json_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.json' +# img_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.jpg' +img_path1 = '/home/gao/mouclear/cc/final_todo/53/cz/raw/53.jpg' +json_path1 = '/home/gao/mouclear/cc/final_todo/53/xj_norm/raw/53.json' +json_path2 = '/home/gao/mouclear/cc/final_todo/53/jj_nonorm_WO3/raw/53.json' +json_path3 = '/home/gao/mouclear/cc/final_todo/53/cz_mask/raw/mask.json' +# json_path4 = '/home/gao/mouclear/cc/final_todo/53/kw/53.json' +json_path4 = '/home/gao/mouclear/cc/final_todo/53/kw/mask.json' + + + +plot2(img_path1,json_path1, json_path2, json_path3, json_path4) \ No newline at end of file diff --git a/egnn_v2/predict_pl_v2_aug.py b/egnn_v2/predict_pl_v2_aug.py index 980c3de..206a621 100644 --- a/egnn_v2/predict_pl_v2_aug.py +++ b/egnn_v2/predict_pl_v2_aug.py @@ -27,7 +27,8 @@ from torch_geometric.data.lightning import LightningDataset from torch.utils.data import ConcatDataset from egnn_utils.save import save_results from resize import resize_images -from metricse2e_vor import post_process, norm_line_sv_label, cz_label +from metricse2e_vor import post_process +from model_type_dict import norm_line_sv_label, cz_label from plot_view import plot2, plot1 from dp.launching.report import Report, ReportSection, AutoReportElement # constants @@ -219,14 +220,14 @@ def predict(model_name, test_path, save_path, edges_length=35.5, model_type=None return save_predict_result(predictions, save_path) -def create_save_path(base_path: str) -> Dict[str, str]: +def create_save_path(base_path: str, model_type: str) -> Dict[str, str]: # 定义需要创建的目录结构 paths = { - "gnn_predict_dataset": os.path.join(base_path, "gnn_predict_dataset"), - "gnn_predict_result" : os.path.join(base_path, 'gnn_predict_json'), - "gnn_predict_post_process": os.path.join(base_path, "gnn_predict_post_process"), - "gnn_predict_connect_view": os.path.join(base_path, "gnn_predict_connect_view"), - "gnn_predict_result_view": os.path.join(base_path, "gnn_predict_result_view"), + "gnn_predict_dataset": os.path.join(base_path, model_type,"gnn_predict_dataset"), + "gnn_predict_result" : os.path.join(base_path, model_type, 'gnn_predict_json'), + "gnn_predict_post_process": os.path.join(base_path, model_type, "gnn_predict_post_process"), + "gnn_predict_connect_view": os.path.join(base_path, model_type, "gnn_predict_connect_view"), + "gnn_predict_result_view": os.path.join(base_path, model_type, "gnn_predict_result_view"), } # 创建目录 @@ -288,7 +289,7 @@ def plot_connect_line(json_folder: str, output_folder: str): def predict_and_plot(model_name, test_path, save_base_path, edges_length=35.5, model_type=norm_line_sv_label): print("gnn model: ", model_name) #创建保存路径 - save_path = create_save_path(save_base_path) + save_path = create_save_path(save_base_path, model_type) # 数据预处理, resize图片 
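The `create_save_path(base_path, model_type)` change above nests every prediction artifact under a per-model-type folder, which is what lets one `output_dir` hold results for several classifiers at once. A condensed sketch of the resulting layout, assuming the directory-creation loop hidden by the hunk uses `os.makedirs`:

```python
import os
from typing import Dict

def create_save_path(base_path: str, model_type: str) -> Dict[str, str]:
    # e.g. base_path="./result_combine_4", model_type="正常掺杂分类" yields
    # ./result_combine_4/正常掺杂分类/gnn_predict_dataset, .../gnn_predict_json, ...
    paths = {
        "gnn_predict_dataset": os.path.join(base_path, model_type, "gnn_predict_dataset"),
        "gnn_predict_result": os.path.join(base_path, model_type, "gnn_predict_json"),
        "gnn_predict_post_process": os.path.join(base_path, model_type, "gnn_predict_post_process"),
        "gnn_predict_connect_view": os.path.join(base_path, model_type, "gnn_predict_connect_view"),
        "gnn_predict_result_view": os.path.join(base_path, model_type, "gnn_predict_result_view"),
    }
    for path in paths.values():  # assumed implementation of the "# 创建目录" step
        os.makedirs(path, exist_ok=True)
    return paths
```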
     resize_images(512,256,test_path)
     #预测结果
diff --git a/msunet/app.py b/msunet/app.py
index fdc8642..d575d3b 100644
--- a/msunet/app.py
+++ b/msunet/app.py
@@ -32,7 +32,8 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn
 print(sys.path)
 from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
 from train_pl_v2_aug import train as egnn_train
-from metricse2e_vor import norm_line_sv_label, cz_label, xj_label, smov_label, model_path_dict
+from model_type_dict import norm_line_sv_label, cz_label, xj_label, smov_label, model_path_dict
+from plot_view import plot_multiple_annotations
 
 class PredictOptions(BaseModel):
     data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
diff --git a/msunet/test.json b/msunet/test.json
index 1e657a0..a7c8e9a 100644
--- a/msunet/test.json
+++ b/msunet/test.json
@@ -1,10 +1,16 @@
 {
-    "data_path": "sing.zip",
+    "data_path": "combine.zip",
     "edges_length": 28,
-    "output_dir": "./result_norm_new",
+    "output_dir": "./result_combine_4",
     "select_models": [
         {
-            "model_name": "点线缺陷分类"
+            "model_name": "012三种不同结构的原子类型分类"
+        },
+        {
+            "model_name": "正常掺杂分类"
+        },
+        {
+            "model_name": "SMoV三种原子分类"
         }
     ]
 }
\ No newline at end of file
diff --git a/msunet/train_pl.py b/msunet/train_pl.py
index 0267225..a0812ee 100755
--- a/msunet/train_pl.py
+++ b/msunet/train_pl.py
@@ -2,6 +2,7 @@ import argparse
 import os
 import json
 import glob
+import shutil
 import zipfile
 
 import torch
@@ -203,8 +204,17 @@ def train(process_path: str, model_save_path: str)->str:
         valid_loader
     )
 
-    model_name = os.path.join(model_save_path, 'last.ckpt')
-    trainer.save_checkpoint(model_name)
+    # model_name = os.path.join(model_save_path, 'last.ckpt')
+    # trainer.save_checkpoint(model_name)
+
+    last_checkpoint_path = checkpoint_callback.last_model_path
+
+    if last_checkpoint_path:
+        model_name = os.path.join(model_save_path, 'last.ckpt')
+        # Copy the last checkpoint to the desired path
+        shutil.copy(last_checkpoint_path, model_name)
+    else:
+        print("No checkpoint was saved by the checkpoint callback.")
 
     return model_name
 
 # # inference
@@ -247,23 +257,29 @@ def create_save_path(base_path: str) -> Dict[str, str]:
 def arg_parse()->argparse.Namespace:
     # # 增加参数一个是数据集的路径,另外一个是保存的路径
     parser = argparse.ArgumentParser(description='Process some integers.')
-    parser.add_argument('--data_path', type=str, default='', help='path to dataset')
-    parser.add_argument('--save_path', type=str, default='./train_result', help='path to save result')
+    parser.add_argument('--dataset', default="", help='path to test dataset')
+    parser.add_argument('--model_output', default="", help='path to model result path')
+    # parser.add_argument('--data_path', type=str, default='', help='path to dataset')
+    # parser.add_argument('--save_path', type=str, default='./train_result', help='path to save result')
     args = parser.parse_args()
     return args
 
 if __name__ == '__main__':
     args = arg_parse()
     #创建目录
-    save_path = create_save_path(args.save_path)
+    save_path = create_save_path(args.model_output)
     print("save path: ", save_path)
     original_path = save_path['train_original']
     train_pre_process_path = save_path['train_pre_process']
     model_save_path = save_path['model_save_path']
 
+    #从dataset目录获取数据集名称
+    for file in os.listdir(args.dataset):
+        if file.endswith('.zip'):
+            zip_file_path = os.path.join(args.dataset, file)
     # 解压数据集到original_path
-    with zipfile.ZipFile(args.data_path, 'r') as zip_ref:
+    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
         zip_ref.extractall(original_path)
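In the `train_pl.py` hunk above, `trainer.save_checkpoint(model_name)` is replaced by copying the file already written by the Lightning `ModelCheckpoint` callback (`checkpoint_callback.last_model_path`). Note that `model_name` appears to be bound only inside the `if` branch, so the fallback path can reach `return model_name` with an unbound name. A minimal sketch of the same idea, assuming `checkpoint_callback` is the `ModelCheckpoint` passed to the `Trainer`, returning an empty path instead:

```python
import os
import shutil
from pytorch_lightning.callbacks import ModelCheckpoint

def export_last_checkpoint(checkpoint_callback: ModelCheckpoint, model_save_path: str) -> str:
    """Copy the callback-managed 'last' checkpoint to <model_save_path>/last.ckpt."""
    last_checkpoint_path = checkpoint_callback.last_model_path  # populated after trainer.fit()
    if not last_checkpoint_path:
        print("No checkpoint was saved by the checkpoint callback.")
        return ""
    model_name = os.path.join(model_save_path, "last.ckpt")
    # Copying reuses the exact file the callback wrote instead of serializing the model again.
    shutil.copy(last_checkpoint_path, model_name)
    return model_name
```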
     #image_path = './train_and_test'
@@ -273,10 +289,10 @@ if __name__ == '__main__':
     model_path = train(train_pre_process_path, model_save_path)
 
     print("test start")
-    print("test start")
-    test_path= os.path.join(original_path, 'test')
-    predict_and_plot(model_path, test_path, args.save_path, True)
-    print("test end")
+    # print("test start")
+    # test_path= os.path.join(original_path, 'test')
+    # predict_and_plot(model_path, test_path, args.save_path, True)
+    print("test end")
 
 
 # if __name__ == '__main__':
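The updated `msunet/test.json` now lists three model types under `select_models`, and `model_path_dict` from the new `model_type_dict.py` maps each label to its checkpoint. The diff does not show how `app.py` iterates over that selection, so the driver loop below is an assumption, a sketch of how the pieces plausibly fit together:

```python
import json

from model_type_dict import model_path_dict
from predict_pl_v2_aug import predict_and_plot

def run_selected_models(config_path: str, test_path: str) -> None:
    """Hypothetical driver: run every model type listed in test.json."""
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)
    for entry in config["select_models"]:
        model_type = entry["model_name"]        # e.g. "正常掺杂分类"
        ckpt = model_path_dict[model_type]      # e.g. "model/cz_new.ckpt"
        # Results land in <output_dir>/<model_type>/... via create_save_path.
        predict_and_plot(ckpt, test_path, config["output_dir"],
                         edges_length=config.get("edges_length", 35.5),
                         model_type=model_type)

# run_selected_models("msunet/test.json", "<extracted test set>")  # illustrative call
```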