组合模式画图

This commit is contained in:
somunslotus 2024-07-26 16:14:40 +08:00
parent a9fd74feb9
commit 02009a8e52
11 changed files with 823 additions and 78 deletions

View File

@ -9,7 +9,7 @@ from PIL import Image
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset, download_url, Dataset
from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay
#from metricse2e_vor import norm_line_sv_label
from model_type_dict import norm_line_sv_label
import albumentations as A
try:
@ -27,7 +27,7 @@ def get_training_augmentation():
return A.Compose(train_transform)
def get_validation_augmentation(model_type=None):
if model_type == "点线缺陷分类":
if model_type == norm_line_sv_label:
print("线缺陷检测模型需要归一化")
test_transform = [
A.Normalize(),

View File

@ -1,6 +1,6 @@
import torch.nn as nn
from egnn_core.egnn_clean import EGNN
from metricse2e_vor import cz_label
from model_type_dict import cz_label
class PL_EGNN(nn.Module):
def __init__(self, model_type=None):

View File

@ -1,66 +1,15 @@
import os
import shutil
import cv2
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from egnn_utils.e2e_metrics import get_metrics
from egnn_core.data import get_y_3
from egnn_core.data import load_data
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
class_dict = {
0: 'Norm',
1: 'SV',
2: 'LineSV',
}
class_dict_rev = {
'Norm': 1,
'SV': 2,
'LineSV': 3,
}
#cz:
cz_class_dict = {
0: 'atom',
1: 'cz',
}
#xj:
xj_class_dict = {
0: '0',
1: '1',
2: '2',
}
#smov:
smov_class_dict = {
0: 'S',
1: 'Mo',
2: 'V'
}
norm_line_sv_label = "点线缺陷分类"
cz_label = "正常掺杂分类"
xj_label = "012三种不同结构的原子类型分类"
smov_label = "S/Mo/V三种原子分类"
model_path_dict = {
norm_line_sv_label: "model/last.ckpt",
cz_label: "model/cz_new.ckpt",
xj_label: "model/xj_new_2.ckpt",
smov_label: "model/smov.ckpt"
}
from model_type_dict import class_dict, cz_class_dict, xj_class_dict, smov_class_dict, norm_line_sv_label, \
cz_label, xj_label, smov_label
def get_class_dict(model_type):
if model_type == cz_label:

View File

@ -0,0 +1,37 @@
class_dict = {
0: 'Norm',
1: 'SV',
2: 'LineSV',
}
class_dict_rev = {
'Norm': 1,
'SV': 2,
'LineSV': 3,
}
cz_class_dict = {
0: 'atom',
1: 'cz',
}
xj_class_dict = {
0: '0',
1: '1',
2: '2',
}
smov_class_dict = {
0: 'S',
1: 'Mo',
2: 'V'
}
norm_line_sv_label = "点线缺陷分类"
cz_label = "正常掺杂分类"
xj_label = "012三种不同结构的原子类型分类"
smov_label = "SMoV三种原子分类"
model_path_dict = {
norm_line_sv_label: "model/last.ckpt",
cz_label: "model/cz_new.ckpt",
xj_label: "model/xj_new_2.ckpt",
smov_label: "model/smov.ckpt"
}

View File

@ -88,6 +88,217 @@ def plot2(img_folder, json_folder, psize=26, output_folder='output'):
plt.savefig(output_path, dpi = 300)
plt.close() # 关闭当前的绘图窗口,以避免内存泄漏
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import json
# 假设你有一个用于加载JSON数据的函数
def plot_multiple_annotations(img_folder, json_folders, psize=26, output_folder='output'):
# 定义颜色数组
colors = ['#9BB6CF', '#76F1A2', '#EDC08C', 'red']
model_colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#FF00FF', '#00FFFF']
# 确保输出文件夹存在
if not os.path.exists(output_folder):
os.makedirs(output_folder, exist_ok=True)
# 遍历指定文件夹中的所有文件
for img_filename in os.listdir(img_folder):
if img_filename.endswith(".jpg"): # 确保是jpeg文件
img_path = os.path.join(img_folder, img_filename)
img = cv2.imread(img_path, 0)
plt.figure(figsize=(9, 9))
plt.imshow(img, cmap='gray')
for model_idx, json_folder in enumerate(json_folders):
json_filename = img_filename.replace('.jpg', '.json')
print(f"Processing {json_filename}")
json_path = os.path.join(json_folder, json_filename)
if os.path.exists(json_path):
points, edge_index, labels, _ = load_data_v2(json_path)
mask_pd = np.zeros(img.shape, np.uint8)
mask_pd[points[:, 0], points[:, 1]] = labels + 1
for i in range(4):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=model_colors[model_idx], label=f'Model {model_idx+1}, Class {i+1}', alpha=0.5)
plt.axis('off')
plt.tight_layout()
# 保存图像
output_path = os.path.join(output_folder, img_filename)
print(f"Save plot to {output_path}")
plt.savefig(output_path, dpi=300)
plt.close()
def plot_combine(img_path, json_path, json_path2, json_path3, json_path4, psize=40):
c = ['#dfc8a0', '#814d81', '#debd97']
##814d81紫色
points, edge_index, labels, _ = load_data_v2(json_path)
img = cv2.imread(img_path, 0)
select_point_index = np.where(labels == 1)[0]
mask_pd = np.zeros(img.shape, np.uint8)
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
plt.figure(figsize=(9, 9))
#img = np.array(Image.open(img_path))
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
c = ['#debd97', '#84aad0', '#debd97']
points, edge_index, labels, _ = load_data_v2(json_path2)
AA = np.where(labels == 0)[0]
AA = AA[~np.isin(AA, select_point_index)]
VV = np.where(labels == 1)[0]
VV = VV[~np.isin(VV, select_point_index)]
labelsv = labels[VV]
pointsv = points[VV]
mask_pd = np.zeros(img.shape, np.uint8)
mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
# VV = np.where(labels == 2)[0]
labels = labels[AA]
points = points[AA]
mask_pd = np.zeros(img.shape, np.uint8)
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
c = ['#84aad0', '#ff5b5b', '#ff5b5b']
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
points, edge_index, labels, _ = load_data_v2(json_path3)
VV = np.where(labels == 1)[0]
VV = VV[~np.isin(VV, select_point_index)]
labels = labels[VV]
points = points[VV]
mask_pd = np.zeros(img.shape, np.uint8)
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
# c = ['#84aad0', '#92b964', '#ff5b5b']
# points, edge_index, labels, _ = load_data_v2(json_path4)
#
# VV = np.where(labels == 1)[0]
# VV = VV[~np.isin(VV, select_point_index)]
#
# labels = labels[VV]
# points = points[VV]
# mask_pd = np.zeros((512, 512)) + 256
# mask_pd[points[:, 0], points[:, 1]] = labels + 1
# mask_pd = np.array(mask_pd, np.uint8)
# for i in range(3):
# h, w = np.where(mask_pd == i + 1)
# plt.scatter(w, h, s=psize, c=c[i])
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.savefig('kw.jpg', dpi=300)
plt.show()
def plot_combine1(img_path, json_paths, psize=40):
colors = [
['#dfc8a0', '#814d81', '#debd97'],
['#debd97', '#84aad0', '#debd97'],
['#84aad0', '#ff5b5b', '#ff5b5b'],
['#84aad0', '#92b964', '#ff5b5b']
]
select_point_index = []
# 初始化绘图
plt.figure(figsize=(9, 9))
img = np.array(Image.open(img_path))
plt.imshow(img, cmap='gray')
for idx, json_path in enumerate(json_paths):
points, edge_index, labels, _ = load_data_v2(json_path)
# 仅对于第一个json文件记录选择的点索引
if idx == 0:
select_point_index = np.where(labels == 1)[0]
if idx == 0:
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
else:
AA = np.where(labels == 0)[0]
AA = AA[~np.isin(AA, select_point_index)]
VV = np.where(labels == 1)[0]
VV = VV[~np.isin(VV, select_point_index)]
if idx in [1, 2]:
labels_v = labels[VV]
points_v = points[VV]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points_v[:, 0], points_v[:, 1]] = labels_v + 1
mask_pd = np.array(mask_pd, np.uint8)
else:
labels = labels[VV]
points = points[VV]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=colors[idx][i])
plt.axis('off')
plt.tight_layout()
plt.savefig('kw.jpg', dpi=300)
plt.show()
if __name__ == '__main__':
# 示例调用
try:
folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process'
img_folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.jpg'
json_folders = [
'/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.json',
'/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process/49.json',
'/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process/49.json',
''
]
#plot2(folder, folder)
plot_combine(img_folder, json_folders[0], json_folders[1], json_folders[2], json_folders[3])
#plot_multiple_annotations(img_folder, json_folders, output_folder='/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/combine_view')
except Exception as e:
import traceback
print(f"Error processing error: {e}")
traceback.print_exc()
# 指定图像和JSON文件的文件夹路径
# img_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv'
# json_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv'

View File

@ -0,0 +1,199 @@
import os
import cv2
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from core.data import load_data_v2_plot,load_data_v2_plot_jj,load_data_v2_plot_cz,load_data_v2_plot_kw
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
def plot1(json_path, psize=24, img_size=2048):
bg = np.zeros((img_size, img_size)) + 255
bg[0, 0] = 0
points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28)
mask_pd = np.zeros((img_size, img_size))
mask_pd[points[:, 0], points[:, 1]] = 1
mask_pd = np.array(mask_pd, np.uint8)
plt.figure(figsize=(9, 9))
plt.imshow(bg, cmap='gray')
h, w = np.where(mask_pd != 0)
plt.scatter(points[:, 1], points[:, 0], s=16, c='#8EB9D9', zorder=2)
#
for idx, (s, e) in enumerate(edge_index.T):
s = points[s]
e = points[e]
plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
plt.axis('off')
plt.tight_layout()
def plot2(img_path, json_path, json_path2, json_path3, json_path4, psize=40):
# c = ['red', 'green', 'yellow']
# c = [ '#dfc8a0', '#84aad0','#93ba66']
# c = [ '#dfc8a0','#debd97', '#84aad0']
# c = ['#9BB6CF', '#9BB6CF', '#9BB6CF']
# c = ['#83a2c1', '#9BB6CF', '#9BB6CF']
# points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=25)
c = ['#dfc8a0', '#84aad0', '#93ba66']
points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28)
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path,max_edge_length=31)
# mask_pd = np.zeros((2048, 2048))
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
plt.figure(figsize=(9, 9))
img = np.array(Image.open(img_path))
# bg = np.zeros_like(img) + 255
# bg[0, 0] = 0
# plt.imshow(bg, cmap='gray')
# for idx, (s, e) in enumerate(edge_index.T):
# s = points[s]
# e = points[e]
# plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
# for i in range(3):
# h, w = np.where(mask_pd == i + 1)
# plt.scatter(w, h, s=psize, c=c[i])
#
# plt.imshow(img, cmap='gray')
# plt.axis('off')
# plt.tight_layout()
# plt.show()
# xj_points = points
# c = ['#84aad0', '#92b964', '#84aad0']
# #
# for idx, (s, e) in enumerate(edge_index.T):
# s = points[s]
# e = points[e]
# plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
# c = [ '#84aad0','#92b964','#92b964']
#debd97 黄 92b964 绿 84aad0 蓝
# c = [ '#92b964','#debd97','#debd97']
c = [ '#debd97','#84aad0','#debd97']
points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
AA = np.where(labels == 0)[0]
VV = np.where(labels == 1)[0]
labelsv = labels[VV]
pointsv = points[VV]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
# VV = np.where(labels == 2)[0]
labels = labels[AA]
points = points[AA]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
# plt.imshow(img, cmap='gray')
# plt.axis('off')
# plt.tight_layout()
# plt.show()
c = ['#84aad0', '#92b964', '#ff5b5b']
points, edge_index, labels, _ = load_data_v2_plot_kw(json_path4, max_edge_length=28)
VV = np.where(labels == 1)[0]
labels = labels[VV]
points = points[VV]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
c = ['#84aad0', '#ff5b5b', '#ff5b5b']
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
points, edge_index, labels, _ = load_data_v2_plot_cz(json_path3, max_edge_length=28)
# a = []
# b = []
#
# for i in points:
# if i not in xj_points:
# print(i)
VV = np.where(labels == 1)[0]
labels = labels[VV]
points = points[VV]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
# plt.imshow(img, cmap='gray')
# plt.axis('off')
# plt.tight_layout()
#
#
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.savefig('49.jpg', dpi=300)
plt.show()
# json_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.json'
# img_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.jpg'
# json_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.json'
# img_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.jpg'
# json_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.json'
# img_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.jpg'
# json_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.json'
# img_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.jpg'
json_path1 = '/home/gao/mouclear/cc/final_todo/fig_50/xj_result/raw/49.json'
img_path1 = '/home/gao/mouclear/cc/final_todo/fig_50/xj_result/raw/49.jpg'
json_path2 = '/home/gao/mouclear/cc/final_todo/fig_50/jj_result/raw/49.json'
json_path3 = '/home/gao/mouclear/cc/final_todo/fig_50/cz_result/raw/49.json'
json_path4 = '/home/gao/mouclear/cc/final_todo/kw_49_v2/49.json'
plot2(img_path1,json_path1, json_path2, json_path3, json_path4)

View File

@ -0,0 +1,325 @@
import os
import cv2
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from core.data import load_data_v2_plot,load_data_v2_plot_jj,load_data_v2_plot_cz,load_data_v2_plot_kw
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
def plot1(json_path, psize=24, img_size=2048):
bg = np.zeros((img_size, img_size)) + 255
bg[0, 0] = 0
points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28)
mask_pd = np.zeros((img_size, img_size))
mask_pd[points[:, 0], points[:, 1]] = 1
mask_pd = np.array(mask_pd, np.uint8)
plt.figure(figsize=(9, 9))
plt.imshow(bg, cmap='gray')
h, w = np.where(mask_pd != 0)
plt.scatter(points[:, 1], points[:, 0], s=16, c='#8EB9D9', zorder=2)
#
for idx, (s, e) in enumerate(edge_index.T):
s = points[s]
e = points[e]
plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
plt.axis('off')
plt.tight_layout()
def plot_combine(img_path, json_path, json_path2, json_path3, json_path4, psize=40):
c = ['#dfc8a0', '#814d81', '#debd97']
##814d81紫色
points, edge_index, labels, _ = load_data_v2(json_path)
select_point_index = np.where(labels == 1)[0]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
plt.figure(figsize=(9, 9))
img = np.array(Image.open(img_path))
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
c = [ '#debd97','#84aad0','#debd97']
points, edge_index, labels, _ = load_data_v2(json_path2)
AA = np.where(labels == 0)[0]
AA = AA[~np.isin(AA, select_point_index)]
VV = np.where(labels == 1)[0]
VV = VV[~np.isin(VV, select_point_index)]
labelsv = labels[VV]
pointsv = points[VV]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
# VV = np.where(labels == 2)[0]
labels = labels[AA]
points = points[AA]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
c = ['#84aad0', '#ff5b5b', '#ff5b5b']
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
points, edge_index, labels, _ = load_data_v2(json_path3)
VV = np.where(labels == 1)[0]
VV = VV[~np.isin(VV, select_point_index)]
labels = labels[VV]
points = points[VV]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
c = ['#84aad0', '#92b964', '#ff5b5b']
points, edge_index, labels, _ = load_data_v2(json_path4)
VV = np.where(labels == 1)[0]
VV = VV[~np.isin(VV, select_point_index)]
labels = labels[VV]
points = points[VV]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.savefig('kw.jpg', dpi=300)
plt.show()
def plot2(img_path, json_path, json_path2, json_path3, json_path4, psize=40):
# c = ['red', 'green', 'yellow']
# c = [ '#dfc8a0', '#84aad0','#93ba66']
# c = [ '#dfc8a0','#debd97', '#84aad0']
# c = ['#9BB6CF', '#9BB6CF', '#9BB6CF']
# c = ['#83a2c1', '#9BB6CF', '#9BB6CF']
# points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=25)
# c = ['#dfc8a0', '#84aad0', '#93ba66']
# c = ['#dfc8a0', '#814d81', '#debd97']
c = ['#dfc8a0', '#814d81', '#debd97']
##814d81紫色
points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28)
select_point_index = np.where(labels == 1)[0]
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path,max_edge_length=31)
# mask_pd = np.zeros((2048, 2048))
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
plt.figure(figsize=(9, 9))
img = np.array(Image.open(img_path))
# bg = np.zeros_like(img) + 255
# bg[0, 0] = 0
# plt.imshow(bg, cmap='gray')
# for idx, (s, e) in enumerate(edge_index.T):
# s = points[s]
# e = points[e]
# plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
# plt.imshow(img, cmap='gray')
# plt.axis('off')
# plt.tight_layout()
# plt.show()
# #
# for idx, (s, e) in enumerate(edge_index.T):
# s = points[s]
# e = points[e]
# plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
# c = [ '#84aad0','#92b964','#92b964']
#debd97 黄 92b964 绿 84aad0 蓝
# c = [ '#92b964','#debd97','#debd97']
c = [ '#debd97','#84aad0','#debd97']
points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
AA = np.where(labels == 0)[0]
AA = AA[~np.isin(AA, select_point_index)]
VV = np.where(labels == 1)[0]
VV = VV[~np.isin(VV, select_point_index)]
labelsv = labels[VV]
pointsv = points[VV]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
# VV = np.where(labels == 2)[0]
labels = labels[AA]
points = points[AA]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
c = ['#84aad0', '#ff5b5b', '#ff5b5b']
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
points, edge_index, labels, _ = load_data_v2_plot_cz(json_path3, max_edge_length=28)
VV = np.where(labels == 1)[0]
VV = VV[~np.isin(VV, select_point_index)]
labels = labels[VV]
points = points[VV]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
# plt.imshow(img, cmap='gray')
# plt.axis('off')
# plt.tight_layout()
#
# plt.savefig('cz.jpg', dpi=300)
# plt.show()
# plt.imshow(img, cmap='gray')
# plt.axis('off')
# plt.tight_layout()
#
# plt.savefig('JJ.jpg', dpi=300)
# plt.show()
# plt.imshow(img, cmap='gray')
# plt.axis('off')
# plt.tight_layout()
# plt.show()
c = ['#84aad0', '#92b964', '#ff5b5b']
points, edge_index, labels, _ = load_data_v2_plot_kw(json_path4, max_edge_length=28)
VV = np.where(labels == 1)[0]
VV = VV[~np.isin(VV, select_point_index)]
labels = labels[VV]
points = points[VV]
mask_pd = np.zeros((512, 512)) + 256
mask_pd[points[:, 0], points[:, 1]] = labels + 1
mask_pd = np.array(mask_pd, np.uint8)
for i in range(3):
h, w = np.where(mask_pd == i + 1)
plt.scatter(w, h, s=psize, c=c[i])
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.tight_layout()
plt.savefig('kw.jpg', dpi=300)
plt.show()
#
#
# c = ['#84aad0', '#ff5b5b', '#ff5b5b']
# # points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
# points, edge_index, labels, _ = load_data_v2_plot_cz(json_path3, max_edge_length=28)
#
# # a = []
# # b = []
# #
# # for i in points:
# # if i not in xj_points:
# # print(i)
# VV = np.where(labels == 1)[0]
# labels = labels[VV]
# points = points[VV]
#
# mask_pd = np.zeros((512, 512)) + 256
# mask_pd[points[:, 0], points[:, 1]] = labels + 1
# mask_pd = np.array(mask_pd, np.uint8)
# for i in range(3):
# h, w = np.where(mask_pd == i + 1)
# plt.scatter(w, h, s=psize, c=c[i])
# plt.imshow(img, cmap='gray')
# plt.axis('off')
# plt.tight_layout()
#
#
#
# plt.imshow(img, cmap='gray')
# plt.axis('off')
# plt.tight_layout()
#
# plt.savefig('49.jpg', dpi=300)
# plt.show()
# json_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.json'
# img_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.jpg'
# json_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.json'
# img_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.jpg'
# json_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.json'
# img_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.jpg'
# json_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.json'
# img_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.jpg'
img_path1 = '/home/gao/mouclear/cc/final_todo/53/cz/raw/53.jpg'
json_path1 = '/home/gao/mouclear/cc/final_todo/53/xj_norm/raw/53.json'
json_path2 = '/home/gao/mouclear/cc/final_todo/53/jj_nonorm_WO3/raw/53.json'
json_path3 = '/home/gao/mouclear/cc/final_todo/53/cz_mask/raw/mask.json'
# json_path4 = '/home/gao/mouclear/cc/final_todo/53/kw/53.json'
json_path4 = '/home/gao/mouclear/cc/final_todo/53/kw/mask.json'
plot2(img_path1,json_path1, json_path2, json_path3, json_path4)

View File

@ -27,7 +27,8 @@ from torch_geometric.data.lightning import LightningDataset
from torch.utils.data import ConcatDataset
from egnn_utils.save import save_results
from resize import resize_images
from metricse2e_vor import post_process, norm_line_sv_label, cz_label
from metricse2e_vor import post_process
from model_type_dict import norm_line_sv_label, cz_label
from plot_view import plot2, plot1
from dp.launching.report import Report, ReportSection, AutoReportElement
# constants
@ -219,14 +220,14 @@ def predict(model_name, test_path, save_path, edges_length=35.5, model_type=None
return save_predict_result(predictions, save_path)
def create_save_path(base_path: str) -> Dict[str, str]:
def create_save_path(base_path: str, model_type: str) -> Dict[str, str]:
# 定义需要创建的目录结构
paths = {
"gnn_predict_dataset": os.path.join(base_path, "gnn_predict_dataset"),
"gnn_predict_result" : os.path.join(base_path, 'gnn_predict_json'),
"gnn_predict_post_process": os.path.join(base_path, "gnn_predict_post_process"),
"gnn_predict_connect_view": os.path.join(base_path, "gnn_predict_connect_view"),
"gnn_predict_result_view": os.path.join(base_path, "gnn_predict_result_view"),
"gnn_predict_dataset": os.path.join(base_path, model_type,"gnn_predict_dataset"),
"gnn_predict_result" : os.path.join(base_path, model_type, 'gnn_predict_json'),
"gnn_predict_post_process": os.path.join(base_path, model_type, "gnn_predict_post_process"),
"gnn_predict_connect_view": os.path.join(base_path, model_type, "gnn_predict_connect_view"),
"gnn_predict_result_view": os.path.join(base_path, model_type, "gnn_predict_result_view"),
}
# 创建目录
@ -288,7 +289,7 @@ def plot_connect_line(json_folder: str, output_folder: str):
def predict_and_plot(model_name, test_path, save_base_path, edges_length=35.5, model_type=norm_line_sv_label):
print("gnn model: ", model_name)
#创建保存路径
save_path = create_save_path(save_base_path)
save_path = create_save_path(save_base_path, model_type)
# 数据预处理, resize图片
resize_images(512,256,test_path)
#预测结果

View File

@ -32,7 +32,8 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn
print(sys.path)
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
from train_pl_v2_aug import train as egnn_train
from metricse2e_vor import norm_line_sv_label, cz_label, xj_label, smov_label, model_path_dict
from model_type_dict import norm_line_sv_label, cz_label, xj_label, smov_label, model_path_dict
from plot_view import plot_multiple_annotations
class PredictOptions(BaseModel):
data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")

View File

@ -1,10 +1,16 @@
{
"data_path": "sing.zip",
"data_path": "combine.zip",
"edges_length": 28,
"output_dir": "./result_norm_new",
"output_dir": "./result_combine_4",
"select_models": [
{
"model_name": "点线缺陷分类"
"model_name": "012三种不同结构的原子类型分类"
},
{
"model_name": "正常掺杂分类"
},
{
"model_name": "SMoV三种原子分类"
}
]
}

View File

@ -2,6 +2,7 @@ import argparse
import os
import json
import glob
import shutil
import zipfile
import torch
@ -203,8 +204,17 @@ def train(process_path: str, model_save_path: str)->str:
valid_loader
)
model_name = os.path.join(model_save_path, 'last.ckpt')
trainer.save_checkpoint(model_name)
# model_name = os.path.join(model_save_path, 'last.ckpt')
# trainer.save_checkpoint(model_name)
last_checkpoint_path = checkpoint_callback.last_model_path
if last_checkpoint_path:
model_name = os.path.join(model_save_path, 'last.ckpt')
# Copy the last checkpoint to the desired path
shutil.copy(last_checkpoint_path, model_name)
else:
print("No checkpoint was saved by the checkpoint callback.")
return model_name
# # inference
@ -247,23 +257,29 @@ def create_save_path(base_path: str) -> Dict[str, str]:
def arg_parse()->argparse.Namespace:
# # 增加参数一个是数据集的路径,另外一个是保存的路径
parser = argparse.ArgumentParser(description='Process some integers.')
parser.add_argument('--data_path', type=str, default='', help='path to dataset')
parser.add_argument('--save_path', type=str, default='./train_result', help='path to save result')
parser.add_argument('--dataset', default="", help='path to test dataset')
parser.add_argument('--model_output', default="", help='path to model result path')
# parser.add_argument('--data_path', type=str, default='', help='path to dataset')
# parser.add_argument('--save_path', type=str, default='./train_result', help='path to save result')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = arg_parse()
#创建目录
save_path = create_save_path(args.save_path)
save_path = create_save_path(args.model_output)
print("save path: ", save_path)
original_path = save_path['train_original']
train_pre_process_path = save_path['train_pre_process']
model_save_path = save_path['model_save_path']
#从dataset目录获取数据集名称
for file in os.listdir(args.dataset):
if file.endswith('.zip'):
zip_file_path = os.path.join(args.dataset, file)
# 解压数据集到original_path
with zipfile.ZipFile(args.data_path, 'r') as zip_ref:
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
zip_ref.extractall(original_path)
#image_path = './train_and_test'
@ -273,10 +289,10 @@ if __name__ == '__main__':
model_path = train(train_pre_process_path, model_save_path)
print("test start")
print("test start")
test_path= os.path.join(original_path, 'test')
predict_and_plot(model_path, test_path, args.save_path, True)
print("test end")
# print("test start")
# test_path= os.path.join(original_path, 'test')
# predict_and_plot(model_path, test_path, args.save_path, True)
# print("test end")
# if __name__ == '__main__':