组合模式画图
This commit is contained in:
parent
a9fd74feb9
commit
02009a8e52
|
@ -9,7 +9,7 @@ from PIL import Image
|
|||
from torch_geometric.data import Data
|
||||
from torch_geometric.data import InMemoryDataset, download_url, Dataset
|
||||
from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay
|
||||
#from metricse2e_vor import norm_line_sv_label
|
||||
from model_type_dict import norm_line_sv_label
|
||||
import albumentations as A
|
||||
|
||||
try:
|
||||
|
@ -27,7 +27,7 @@ def get_training_augmentation():
|
|||
return A.Compose(train_transform)
|
||||
|
||||
def get_validation_augmentation(model_type=None):
|
||||
if model_type == "点线缺陷分类":
|
||||
if model_type == norm_line_sv_label:
|
||||
print("线缺陷检测模型需要归一化")
|
||||
test_transform = [
|
||||
A.Normalize(),
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import torch.nn as nn
|
||||
from egnn_core.egnn_clean import EGNN
|
||||
from metricse2e_vor import cz_label
|
||||
from model_type_dict import cz_label
|
||||
|
||||
class PL_EGNN(nn.Module):
|
||||
def __init__(self, model_type=None):
|
||||
|
|
|
@ -1,66 +1,15 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
import cv2
|
||||
import glob
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from egnn_utils.e2e_metrics import get_metrics
|
||||
from egnn_core.data import get_y_3
|
||||
from egnn_core.data import load_data
|
||||
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
|
||||
|
||||
|
||||
class_dict = {
|
||||
0: 'Norm',
|
||||
1: 'SV',
|
||||
2: 'LineSV',
|
||||
}
|
||||
|
||||
class_dict_rev = {
|
||||
'Norm': 1,
|
||||
'SV': 2,
|
||||
'LineSV': 3,
|
||||
}
|
||||
|
||||
#cz:
|
||||
cz_class_dict = {
|
||||
0: 'atom',
|
||||
1: 'cz',
|
||||
}
|
||||
|
||||
#xj:
|
||||
xj_class_dict = {
|
||||
0: '0',
|
||||
1: '1',
|
||||
2: '2',
|
||||
}
|
||||
|
||||
#smov:
|
||||
smov_class_dict = {
|
||||
0: 'S',
|
||||
1: 'Mo',
|
||||
2: 'V'
|
||||
}
|
||||
|
||||
norm_line_sv_label = "点线缺陷分类"
|
||||
cz_label = "正常掺杂分类"
|
||||
xj_label = "012三种不同结构的原子类型分类"
|
||||
smov_label = "S/Mo/V三种原子分类"
|
||||
|
||||
model_path_dict = {
|
||||
norm_line_sv_label: "model/last.ckpt",
|
||||
cz_label: "model/cz_new.ckpt",
|
||||
xj_label: "model/xj_new_2.ckpt",
|
||||
smov_label: "model/smov.ckpt"
|
||||
|
||||
}
|
||||
|
||||
from model_type_dict import class_dict, cz_class_dict, xj_class_dict, smov_class_dict, norm_line_sv_label, \
|
||||
cz_label, xj_label, smov_label
|
||||
|
||||
def get_class_dict(model_type):
|
||||
if model_type == cz_label:
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
|
||||
|
||||
class_dict = {
|
||||
0: 'Norm',
|
||||
1: 'SV',
|
||||
2: 'LineSV',
|
||||
}
|
||||
class_dict_rev = {
|
||||
'Norm': 1,
|
||||
'SV': 2,
|
||||
'LineSV': 3,
|
||||
}
|
||||
cz_class_dict = {
|
||||
0: 'atom',
|
||||
1: 'cz',
|
||||
}
|
||||
xj_class_dict = {
|
||||
0: '0',
|
||||
1: '1',
|
||||
2: '2',
|
||||
}
|
||||
smov_class_dict = {
|
||||
0: 'S',
|
||||
1: 'Mo',
|
||||
2: 'V'
|
||||
}
|
||||
norm_line_sv_label = "点线缺陷分类"
|
||||
cz_label = "正常掺杂分类"
|
||||
xj_label = "012三种不同结构的原子类型分类"
|
||||
smov_label = "SMoV三种原子分类"
|
||||
model_path_dict = {
|
||||
norm_line_sv_label: "model/last.ckpt",
|
||||
cz_label: "model/cz_new.ckpt",
|
||||
xj_label: "model/xj_new_2.ckpt",
|
||||
smov_label: "model/smov.ckpt"
|
||||
|
||||
}
|
|
@ -88,6 +88,217 @@ def plot2(img_folder, json_folder, psize=26, output_folder='output'):
|
|||
plt.savefig(output_path, dpi = 300)
|
||||
plt.close() # 关闭当前的绘图窗口,以避免内存泄漏
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
import json
|
||||
|
||||
# 假设你有一个用于加载JSON数据的函数
|
||||
|
||||
def plot_multiple_annotations(img_folder, json_folders, psize=26, output_folder='output'):
|
||||
# 定义颜色数组
|
||||
colors = ['#9BB6CF', '#76F1A2', '#EDC08C', 'red']
|
||||
model_colors = ['#FF0000', '#00FF00', '#0000FF', '#FFFF00', '#FF00FF', '#00FFFF']
|
||||
|
||||
# 确保输出文件夹存在
|
||||
if not os.path.exists(output_folder):
|
||||
os.makedirs(output_folder, exist_ok=True)
|
||||
|
||||
# 遍历指定文件夹中的所有文件
|
||||
for img_filename in os.listdir(img_folder):
|
||||
|
||||
if img_filename.endswith(".jpg"): # 确保是jpeg文件
|
||||
img_path = os.path.join(img_folder, img_filename)
|
||||
img = cv2.imread(img_path, 0)
|
||||
|
||||
plt.figure(figsize=(9, 9))
|
||||
plt.imshow(img, cmap='gray')
|
||||
|
||||
for model_idx, json_folder in enumerate(json_folders):
|
||||
json_filename = img_filename.replace('.jpg', '.json')
|
||||
print(f"Processing {json_filename}")
|
||||
json_path = os.path.join(json_folder, json_filename)
|
||||
|
||||
if os.path.exists(json_path):
|
||||
points, edge_index, labels, _ = load_data_v2(json_path)
|
||||
mask_pd = np.zeros(img.shape, np.uint8)
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
|
||||
for i in range(4):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=model_colors[model_idx], label=f'Model {model_idx+1}, Class {i+1}', alpha=0.5)
|
||||
|
||||
plt.axis('off')
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图像
|
||||
output_path = os.path.join(output_folder, img_filename)
|
||||
print(f"Save plot to {output_path}")
|
||||
plt.savefig(output_path, dpi=300)
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_combine(img_path, json_path, json_path2, json_path3, json_path4, psize=40):
|
||||
c = ['#dfc8a0', '#814d81', '#debd97']
|
||||
##814d81紫色
|
||||
points, edge_index, labels, _ = load_data_v2(json_path)
|
||||
img = cv2.imread(img_path, 0)
|
||||
select_point_index = np.where(labels == 1)[0]
|
||||
|
||||
mask_pd = np.zeros(img.shape, np.uint8)
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
|
||||
plt.figure(figsize=(9, 9))
|
||||
#img = np.array(Image.open(img_path))
|
||||
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
c = ['#debd97', '#84aad0', '#debd97']
|
||||
|
||||
points, edge_index, labels, _ = load_data_v2(json_path2)
|
||||
|
||||
AA = np.where(labels == 0)[0]
|
||||
AA = AA[~np.isin(AA, select_point_index)]
|
||||
VV = np.where(labels == 1)[0]
|
||||
VV = VV[~np.isin(VV, select_point_index)]
|
||||
|
||||
labelsv = labels[VV]
|
||||
pointsv = points[VV]
|
||||
mask_pd = np.zeros(img.shape, np.uint8)
|
||||
mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
# VV = np.where(labels == 2)[0]
|
||||
labels = labels[AA]
|
||||
points = points[AA]
|
||||
mask_pd = np.zeros(img.shape, np.uint8)
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
c = ['#84aad0', '#ff5b5b', '#ff5b5b']
|
||||
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
|
||||
points, edge_index, labels, _ = load_data_v2(json_path3)
|
||||
|
||||
VV = np.where(labels == 1)[0]
|
||||
VV = VV[~np.isin(VV, select_point_index)]
|
||||
labels = labels[VV]
|
||||
points = points[VV]
|
||||
|
||||
mask_pd = np.zeros(img.shape, np.uint8)
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
# c = ['#84aad0', '#92b964', '#ff5b5b']
|
||||
# points, edge_index, labels, _ = load_data_v2(json_path4)
|
||||
#
|
||||
# VV = np.where(labels == 1)[0]
|
||||
# VV = VV[~np.isin(VV, select_point_index)]
|
||||
#
|
||||
# labels = labels[VV]
|
||||
# points = points[VV]
|
||||
# mask_pd = np.zeros((512, 512)) + 256
|
||||
# mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
# mask_pd = np.array(mask_pd, np.uint8)
|
||||
# for i in range(3):
|
||||
# h, w = np.where(mask_pd == i + 1)
|
||||
# plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
plt.imshow(img, cmap='gray')
|
||||
plt.axis('off')
|
||||
plt.tight_layout()
|
||||
plt.savefig('kw.jpg', dpi=300)
|
||||
plt.show()
|
||||
|
||||
|
||||
def plot_combine1(img_path, json_paths, psize=40):
|
||||
colors = [
|
||||
['#dfc8a0', '#814d81', '#debd97'],
|
||||
['#debd97', '#84aad0', '#debd97'],
|
||||
['#84aad0', '#ff5b5b', '#ff5b5b'],
|
||||
['#84aad0', '#92b964', '#ff5b5b']
|
||||
]
|
||||
|
||||
select_point_index = []
|
||||
|
||||
# 初始化绘图
|
||||
plt.figure(figsize=(9, 9))
|
||||
img = np.array(Image.open(img_path))
|
||||
plt.imshow(img, cmap='gray')
|
||||
|
||||
for idx, json_path in enumerate(json_paths):
|
||||
points, edge_index, labels, _ = load_data_v2(json_path)
|
||||
|
||||
# 仅对于第一个json文件,记录选择的点索引
|
||||
if idx == 0:
|
||||
select_point_index = np.where(labels == 1)[0]
|
||||
|
||||
if idx == 0:
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
else:
|
||||
AA = np.where(labels == 0)[0]
|
||||
AA = AA[~np.isin(AA, select_point_index)]
|
||||
VV = np.where(labels == 1)[0]
|
||||
VV = VV[~np.isin(VV, select_point_index)]
|
||||
|
||||
if idx in [1, 2]:
|
||||
labels_v = labels[VV]
|
||||
points_v = points[VV]
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points_v[:, 0], points_v[:, 1]] = labels_v + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
else:
|
||||
labels = labels[VV]
|
||||
points = points[VV]
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=colors[idx][i])
|
||||
|
||||
plt.axis('off')
|
||||
plt.tight_layout()
|
||||
plt.savefig('kw.jpg', dpi=300)
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 示例调用
|
||||
try:
|
||||
folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process'
|
||||
img_folder = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.jpg'
|
||||
json_folders = [
|
||||
'/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/正常掺杂分类/gnn_predict_post_process/49.json',
|
||||
'/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/SMoV三种原子分类/gnn_predict_post_process/49.json',
|
||||
'/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/012三种不同结构的原子类型分类/gnn_predict_post_process/49.json',
|
||||
''
|
||||
]
|
||||
#plot2(folder, folder)
|
||||
plot_combine(img_folder, json_folders[0], json_folders[1], json_folders[2], json_folders[3])
|
||||
#plot_multiple_annotations(img_folder, json_folders, output_folder='/home/deploy/script/bohrim-app/bohrium-app/pythonProject/msunet/result_combine_3/combine_view')
|
||||
except Exception as e:
|
||||
import traceback
|
||||
print(f"Error processing error: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 指定图像和JSON文件的文件夹路径
|
||||
# img_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv'
|
||||
# json_folder = '/home/gao/mouclear/cc/data/all_sv_e2e/sv'
|
||||
|
|
|
@ -0,0 +1,199 @@
|
|||
import os
|
||||
import cv2
|
||||
import glob
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from core.data import load_data_v2_plot,load_data_v2_plot_jj,load_data_v2_plot_cz,load_data_v2_plot_kw
|
||||
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
|
||||
|
||||
|
||||
def plot1(json_path, psize=24, img_size=2048):
|
||||
bg = np.zeros((img_size, img_size)) + 255
|
||||
bg[0, 0] = 0
|
||||
|
||||
points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28)
|
||||
|
||||
mask_pd = np.zeros((img_size, img_size))
|
||||
mask_pd[points[:, 0], points[:, 1]] = 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
|
||||
plt.figure(figsize=(9, 9))
|
||||
plt.imshow(bg, cmap='gray')
|
||||
h, w = np.where(mask_pd != 0)
|
||||
plt.scatter(points[:, 1], points[:, 0], s=16, c='#8EB9D9', zorder=2)
|
||||
#
|
||||
for idx, (s, e) in enumerate(edge_index.T):
|
||||
s = points[s]
|
||||
e = points[e]
|
||||
plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
|
||||
|
||||
plt.axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
|
||||
def plot2(img_path, json_path, json_path2, json_path3, json_path4, psize=40):
|
||||
# c = ['red', 'green', 'yellow']
|
||||
# c = [ '#dfc8a0', '#84aad0','#93ba66']
|
||||
|
||||
# c = [ '#dfc8a0','#debd97', '#84aad0']
|
||||
|
||||
# c = ['#9BB6CF', '#9BB6CF', '#9BB6CF']
|
||||
# c = ['#83a2c1', '#9BB6CF', '#9BB6CF']
|
||||
|
||||
# points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=25)
|
||||
c = ['#dfc8a0', '#84aad0', '#93ba66']
|
||||
points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28)
|
||||
|
||||
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path,max_edge_length=31)
|
||||
|
||||
# mask_pd = np.zeros((2048, 2048))
|
||||
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
|
||||
plt.figure(figsize=(9, 9))
|
||||
|
||||
img = np.array(Image.open(img_path))
|
||||
# bg = np.zeros_like(img) + 255
|
||||
# bg[0, 0] = 0
|
||||
# plt.imshow(bg, cmap='gray')
|
||||
|
||||
|
||||
|
||||
# for idx, (s, e) in enumerate(edge_index.T):
|
||||
# s = points[s]
|
||||
# e = points[e]
|
||||
# plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
|
||||
|
||||
# for i in range(3):
|
||||
# h, w = np.where(mask_pd == i + 1)
|
||||
# plt.scatter(w, h, s=psize, c=c[i])
|
||||
#
|
||||
# plt.imshow(img, cmap='gray')
|
||||
# plt.axis('off')
|
||||
# plt.tight_layout()
|
||||
# plt.show()
|
||||
|
||||
# xj_points = points
|
||||
|
||||
# c = ['#84aad0', '#92b964', '#84aad0']
|
||||
|
||||
|
||||
# #
|
||||
# for idx, (s, e) in enumerate(edge_index.T):
|
||||
# s = points[s]
|
||||
# e = points[e]
|
||||
# plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
|
||||
|
||||
|
||||
|
||||
|
||||
# c = [ '#84aad0','#92b964','#92b964']
|
||||
#debd97 黄 92b964 绿 84aad0 蓝
|
||||
# c = [ '#92b964','#debd97','#debd97']
|
||||
c = [ '#debd97','#84aad0','#debd97']
|
||||
points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
|
||||
|
||||
AA = np.where(labels == 0)[0]
|
||||
|
||||
VV = np.where(labels == 1)[0]
|
||||
labelsv = labels[VV]
|
||||
pointsv = points[VV]
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
# VV = np.where(labels == 2)[0]
|
||||
labels = labels[AA]
|
||||
points = points[AA]
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
# plt.imshow(img, cmap='gray')
|
||||
# plt.axis('off')
|
||||
# plt.tight_layout()
|
||||
# plt.show()
|
||||
c = ['#84aad0', '#92b964', '#ff5b5b']
|
||||
points, edge_index, labels, _ = load_data_v2_plot_kw(json_path4, max_edge_length=28)
|
||||
|
||||
VV = np.where(labels == 1)[0]
|
||||
labels = labels[VV]
|
||||
points = points[VV]
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
|
||||
c = ['#84aad0', '#ff5b5b', '#ff5b5b']
|
||||
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
|
||||
points, edge_index, labels, _ = load_data_v2_plot_cz(json_path3, max_edge_length=28)
|
||||
|
||||
# a = []
|
||||
# b = []
|
||||
#
|
||||
# for i in points:
|
||||
# if i not in xj_points:
|
||||
# print(i)
|
||||
VV = np.where(labels == 1)[0]
|
||||
labels = labels[VV]
|
||||
points = points[VV]
|
||||
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
# plt.imshow(img, cmap='gray')
|
||||
# plt.axis('off')
|
||||
# plt.tight_layout()
|
||||
|
||||
|
||||
|
||||
#
|
||||
#
|
||||
|
||||
plt.imshow(img, cmap='gray')
|
||||
plt.axis('off')
|
||||
plt.tight_layout()
|
||||
|
||||
plt.savefig('49.jpg', dpi=300)
|
||||
plt.show()
|
||||
|
||||
|
||||
# json_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.json'
|
||||
# img_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.jpg'
|
||||
# json_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.json'
|
||||
# img_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.jpg'
|
||||
# json_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.json'
|
||||
# img_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.jpg'
|
||||
# json_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.json'
|
||||
# img_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.jpg'
|
||||
|
||||
json_path1 = '/home/gao/mouclear/cc/final_todo/fig_50/xj_result/raw/49.json'
|
||||
img_path1 = '/home/gao/mouclear/cc/final_todo/fig_50/xj_result/raw/49.jpg'
|
||||
json_path2 = '/home/gao/mouclear/cc/final_todo/fig_50/jj_result/raw/49.json'
|
||||
json_path3 = '/home/gao/mouclear/cc/final_todo/fig_50/cz_result/raw/49.json'
|
||||
json_path4 = '/home/gao/mouclear/cc/final_todo/kw_49_v2/49.json'
|
||||
|
||||
|
||||
|
||||
plot2(img_path1,json_path1, json_path2, json_path3, json_path4)
|
|
@ -0,0 +1,325 @@
|
|||
import os
|
||||
import cv2
|
||||
import glob
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
from core.data import load_data_v2_plot,load_data_v2_plot_jj,load_data_v2_plot_cz,load_data_v2_plot_kw
|
||||
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
|
||||
|
||||
|
||||
def plot1(json_path, psize=24, img_size=2048):
|
||||
bg = np.zeros((img_size, img_size)) + 255
|
||||
bg[0, 0] = 0
|
||||
|
||||
points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28)
|
||||
|
||||
mask_pd = np.zeros((img_size, img_size))
|
||||
mask_pd[points[:, 0], points[:, 1]] = 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
|
||||
plt.figure(figsize=(9, 9))
|
||||
plt.imshow(bg, cmap='gray')
|
||||
h, w = np.where(mask_pd != 0)
|
||||
plt.scatter(points[:, 1], points[:, 0], s=16, c='#8EB9D9', zorder=2)
|
||||
#
|
||||
for idx, (s, e) in enumerate(edge_index.T):
|
||||
s = points[s]
|
||||
e = points[e]
|
||||
plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
|
||||
|
||||
plt.axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
def plot_combine(img_path, json_path, json_path2, json_path3, json_path4, psize=40):
|
||||
c = ['#dfc8a0', '#814d81', '#debd97']
|
||||
##814d81紫色
|
||||
points, edge_index, labels, _ = load_data_v2(json_path)
|
||||
|
||||
select_point_index = np.where(labels == 1)[0]
|
||||
|
||||
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
|
||||
plt.figure(figsize=(9, 9))
|
||||
img = np.array(Image.open(img_path))
|
||||
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
c = [ '#debd97','#84aad0','#debd97']
|
||||
|
||||
points, edge_index, labels, _ = load_data_v2(json_path2)
|
||||
|
||||
AA = np.where(labels == 0)[0]
|
||||
AA = AA[~np.isin(AA, select_point_index)]
|
||||
VV = np.where(labels == 1)[0]
|
||||
VV = VV[~np.isin(VV, select_point_index)]
|
||||
|
||||
|
||||
labelsv = labels[VV]
|
||||
pointsv = points[VV]
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
# VV = np.where(labels == 2)[0]
|
||||
labels = labels[AA]
|
||||
points = points[AA]
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
|
||||
c = ['#84aad0', '#ff5b5b', '#ff5b5b']
|
||||
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
|
||||
points, edge_index, labels, _ = load_data_v2(json_path3)
|
||||
|
||||
VV = np.where(labels == 1)[0]
|
||||
VV = VV[~np.isin(VV, select_point_index)]
|
||||
labels = labels[VV]
|
||||
points = points[VV]
|
||||
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
c = ['#84aad0', '#92b964', '#ff5b5b']
|
||||
points, edge_index, labels, _ = load_data_v2(json_path4)
|
||||
|
||||
VV = np.where(labels == 1)[0]
|
||||
VV = VV[~np.isin(VV, select_point_index)]
|
||||
|
||||
labels = labels[VV]
|
||||
points = points[VV]
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
plt.imshow(img, cmap='gray')
|
||||
plt.axis('off')
|
||||
plt.tight_layout()
|
||||
plt.savefig('kw.jpg', dpi=300)
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
def plot2(img_path, json_path, json_path2, json_path3, json_path4, psize=40):
|
||||
# c = ['red', 'green', 'yellow']
|
||||
# c = [ '#dfc8a0', '#84aad0','#93ba66']
|
||||
|
||||
# c = [ '#dfc8a0','#debd97', '#84aad0']
|
||||
|
||||
# c = ['#9BB6CF', '#9BB6CF', '#9BB6CF']
|
||||
# c = ['#83a2c1', '#9BB6CF', '#9BB6CF']
|
||||
|
||||
# points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=25)
|
||||
# c = ['#dfc8a0', '#84aad0', '#93ba66']
|
||||
# c = ['#dfc8a0', '#814d81', '#debd97']
|
||||
c = ['#dfc8a0', '#814d81', '#debd97']
|
||||
##814d81紫色
|
||||
points, edge_index, labels, _ = load_data_v2_plot(json_path,max_edge_length=28)
|
||||
|
||||
select_point_index = np.where(labels == 1)[0]
|
||||
|
||||
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path,max_edge_length=31)
|
||||
|
||||
# mask_pd = np.zeros((2048, 2048))
|
||||
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
|
||||
plt.figure(figsize=(9, 9))
|
||||
img = np.array(Image.open(img_path))
|
||||
# bg = np.zeros_like(img) + 255
|
||||
# bg[0, 0] = 0
|
||||
# plt.imshow(bg, cmap='gray')
|
||||
|
||||
# for idx, (s, e) in enumerate(edge_index.T):
|
||||
# s = points[s]
|
||||
# e = points[e]
|
||||
# plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
|
||||
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
# plt.imshow(img, cmap='gray')
|
||||
# plt.axis('off')
|
||||
# plt.tight_layout()
|
||||
# plt.show()
|
||||
|
||||
|
||||
|
||||
# #
|
||||
# for idx, (s, e) in enumerate(edge_index.T):
|
||||
# s = points[s]
|
||||
# e = points[e]
|
||||
# plt.plot([s[1], e[1]], [s[0], e[0]], linewidth=1, c='#C0C0C0', zorder=1)
|
||||
|
||||
# c = [ '#84aad0','#92b964','#92b964']
|
||||
#debd97 黄 92b964 绿 84aad0 蓝
|
||||
# c = [ '#92b964','#debd97','#debd97']
|
||||
c = [ '#debd97','#84aad0','#debd97']
|
||||
|
||||
points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
|
||||
|
||||
AA = np.where(labels == 0)[0]
|
||||
AA = AA[~np.isin(AA, select_point_index)]
|
||||
VV = np.where(labels == 1)[0]
|
||||
VV = VV[~np.isin(VV, select_point_index)]
|
||||
|
||||
|
||||
labelsv = labels[VV]
|
||||
pointsv = points[VV]
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[pointsv[:, 0], pointsv[:, 1]] = labelsv + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
# VV = np.where(labels == 2)[0]
|
||||
labels = labels[AA]
|
||||
points = points[AA]
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
|
||||
c = ['#84aad0', '#ff5b5b', '#ff5b5b']
|
||||
# points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
|
||||
points, edge_index, labels, _ = load_data_v2_plot_cz(json_path3, max_edge_length=28)
|
||||
|
||||
VV = np.where(labels == 1)[0]
|
||||
VV = VV[~np.isin(VV, select_point_index)]
|
||||
labels = labels[VV]
|
||||
points = points[VV]
|
||||
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
# plt.imshow(img, cmap='gray')
|
||||
# plt.axis('off')
|
||||
# plt.tight_layout()
|
||||
#
|
||||
# plt.savefig('cz.jpg', dpi=300)
|
||||
# plt.show()
|
||||
|
||||
# plt.imshow(img, cmap='gray')
|
||||
# plt.axis('off')
|
||||
# plt.tight_layout()
|
||||
#
|
||||
# plt.savefig('JJ.jpg', dpi=300)
|
||||
# plt.show()
|
||||
|
||||
# plt.imshow(img, cmap='gray')
|
||||
# plt.axis('off')
|
||||
# plt.tight_layout()
|
||||
# plt.show()
|
||||
c = ['#84aad0', '#92b964', '#ff5b5b']
|
||||
points, edge_index, labels, _ = load_data_v2_plot_kw(json_path4, max_edge_length=28)
|
||||
|
||||
VV = np.where(labels == 1)[0]
|
||||
VV = VV[~np.isin(VV, select_point_index)]
|
||||
|
||||
labels = labels[VV]
|
||||
points = points[VV]
|
||||
mask_pd = np.zeros((512, 512)) + 256
|
||||
mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
mask_pd = np.array(mask_pd, np.uint8)
|
||||
for i in range(3):
|
||||
h, w = np.where(mask_pd == i + 1)
|
||||
plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
plt.imshow(img, cmap='gray')
|
||||
plt.axis('off')
|
||||
plt.tight_layout()
|
||||
plt.savefig('kw.jpg', dpi=300)
|
||||
plt.show()
|
||||
#
|
||||
#
|
||||
# c = ['#84aad0', '#ff5b5b', '#ff5b5b']
|
||||
# # points, edge_index, labels, _ = load_data_v2_plot_jj(json_path2, max_edge_length=28)
|
||||
# points, edge_index, labels, _ = load_data_v2_plot_cz(json_path3, max_edge_length=28)
|
||||
#
|
||||
# # a = []
|
||||
# # b = []
|
||||
# #
|
||||
# # for i in points:
|
||||
# # if i not in xj_points:
|
||||
# # print(i)
|
||||
# VV = np.where(labels == 1)[0]
|
||||
# labels = labels[VV]
|
||||
# points = points[VV]
|
||||
#
|
||||
# mask_pd = np.zeros((512, 512)) + 256
|
||||
# mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
||||
# mask_pd = np.array(mask_pd, np.uint8)
|
||||
# for i in range(3):
|
||||
# h, w = np.where(mask_pd == i + 1)
|
||||
# plt.scatter(w, h, s=psize, c=c[i])
|
||||
|
||||
# plt.imshow(img, cmap='gray')
|
||||
# plt.axis('off')
|
||||
# plt.tight_layout()
|
||||
|
||||
|
||||
|
||||
#
|
||||
#
|
||||
#
|
||||
# plt.imshow(img, cmap='gray')
|
||||
# plt.axis('off')
|
||||
# plt.tight_layout()
|
||||
#
|
||||
# plt.savefig('49.jpg', dpi=300)
|
||||
# plt.show()
|
||||
|
||||
|
||||
# json_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.json'
|
||||
# img_path4 = '/home/gao/mouclear/detect_40/40/39/jj_v2/after_test_v2/39.jpg'
|
||||
# json_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.json'
|
||||
# img_path1 = '/home/gao/mouclear/detect_40/40/39/cz/v1/39.jpg'
|
||||
# json_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.json'
|
||||
# img_path3 = '/home/gao/mouclear/detect_40/40/39/smo/raw/39.jpg'
|
||||
# json_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.json'
|
||||
# img_path2 = '/home/gao/mouclear/detect_40/40/39/xj_28/v2/39.jpg'
|
||||
img_path1 = '/home/gao/mouclear/cc/final_todo/53/cz/raw/53.jpg'
|
||||
json_path1 = '/home/gao/mouclear/cc/final_todo/53/xj_norm/raw/53.json'
|
||||
json_path2 = '/home/gao/mouclear/cc/final_todo/53/jj_nonorm_WO3/raw/53.json'
|
||||
json_path3 = '/home/gao/mouclear/cc/final_todo/53/cz_mask/raw/mask.json'
|
||||
# json_path4 = '/home/gao/mouclear/cc/final_todo/53/kw/53.json'
|
||||
json_path4 = '/home/gao/mouclear/cc/final_todo/53/kw/mask.json'
|
||||
|
||||
|
||||
|
||||
plot2(img_path1,json_path1, json_path2, json_path3, json_path4)
|
|
@ -27,7 +27,8 @@ from torch_geometric.data.lightning import LightningDataset
|
|||
from torch.utils.data import ConcatDataset
|
||||
from egnn_utils.save import save_results
|
||||
from resize import resize_images
|
||||
from metricse2e_vor import post_process, norm_line_sv_label, cz_label
|
||||
from metricse2e_vor import post_process
|
||||
from model_type_dict import norm_line_sv_label, cz_label
|
||||
from plot_view import plot2, plot1
|
||||
from dp.launching.report import Report, ReportSection, AutoReportElement
|
||||
# constants
|
||||
|
@ -219,14 +220,14 @@ def predict(model_name, test_path, save_path, edges_length=35.5, model_type=None
|
|||
|
||||
return save_predict_result(predictions, save_path)
|
||||
|
||||
def create_save_path(base_path: str) -> Dict[str, str]:
|
||||
def create_save_path(base_path: str, model_type: str) -> Dict[str, str]:
|
||||
# 定义需要创建的目录结构
|
||||
paths = {
|
||||
"gnn_predict_dataset": os.path.join(base_path, "gnn_predict_dataset"),
|
||||
"gnn_predict_result" : os.path.join(base_path, 'gnn_predict_json'),
|
||||
"gnn_predict_post_process": os.path.join(base_path, "gnn_predict_post_process"),
|
||||
"gnn_predict_connect_view": os.path.join(base_path, "gnn_predict_connect_view"),
|
||||
"gnn_predict_result_view": os.path.join(base_path, "gnn_predict_result_view"),
|
||||
"gnn_predict_dataset": os.path.join(base_path, model_type,"gnn_predict_dataset"),
|
||||
"gnn_predict_result" : os.path.join(base_path, model_type, 'gnn_predict_json'),
|
||||
"gnn_predict_post_process": os.path.join(base_path, model_type, "gnn_predict_post_process"),
|
||||
"gnn_predict_connect_view": os.path.join(base_path, model_type, "gnn_predict_connect_view"),
|
||||
"gnn_predict_result_view": os.path.join(base_path, model_type, "gnn_predict_result_view"),
|
||||
}
|
||||
|
||||
# 创建目录
|
||||
|
@ -288,7 +289,7 @@ def plot_connect_line(json_folder: str, output_folder: str):
|
|||
def predict_and_plot(model_name, test_path, save_base_path, edges_length=35.5, model_type=norm_line_sv_label):
|
||||
print("gnn model: ", model_name)
|
||||
#创建保存路径
|
||||
save_path = create_save_path(save_base_path)
|
||||
save_path = create_save_path(save_base_path, model_type)
|
||||
# 数据预处理, resize图片
|
||||
resize_images(512,256,test_path)
|
||||
#预测结果
|
||||
|
|
|
@ -32,7 +32,8 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn
|
|||
print(sys.path)
|
||||
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
|
||||
from train_pl_v2_aug import train as egnn_train
|
||||
from metricse2e_vor import norm_line_sv_label, cz_label, xj_label, smov_label, model_path_dict
|
||||
from model_type_dict import norm_line_sv_label, cz_label, xj_label, smov_label, model_path_dict
|
||||
from plot_view import plot_multiple_annotations
|
||||
|
||||
class PredictOptions(BaseModel):
|
||||
data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
|
||||
|
|
|
@ -1,10 +1,16 @@
|
|||
{
|
||||
"data_path": "sing.zip",
|
||||
"data_path": "combine.zip",
|
||||
"edges_length": 28,
|
||||
"output_dir": "./result_norm_new",
|
||||
"output_dir": "./result_combine_4",
|
||||
"select_models": [
|
||||
{
|
||||
"model_name": "点线缺陷分类"
|
||||
"model_name": "012三种不同结构的原子类型分类"
|
||||
},
|
||||
{
|
||||
"model_name": "正常掺杂分类"
|
||||
},
|
||||
{
|
||||
"model_name": "SMoV三种原子分类"
|
||||
}
|
||||
]
|
||||
}
|
|
@ -2,6 +2,7 @@ import argparse
|
|||
import os
|
||||
import json
|
||||
import glob
|
||||
import shutil
|
||||
import zipfile
|
||||
|
||||
import torch
|
||||
|
@ -203,8 +204,17 @@ def train(process_path: str, model_save_path: str)->str:
|
|||
valid_loader
|
||||
)
|
||||
|
||||
model_name = os.path.join(model_save_path, 'last.ckpt')
|
||||
trainer.save_checkpoint(model_name)
|
||||
# model_name = os.path.join(model_save_path, 'last.ckpt')
|
||||
# trainer.save_checkpoint(model_name)
|
||||
|
||||
last_checkpoint_path = checkpoint_callback.last_model_path
|
||||
|
||||
if last_checkpoint_path:
|
||||
model_name = os.path.join(model_save_path, 'last.ckpt')
|
||||
# Copy the last checkpoint to the desired path
|
||||
shutil.copy(last_checkpoint_path, model_name)
|
||||
else:
|
||||
print("No checkpoint was saved by the checkpoint callback.")
|
||||
|
||||
return model_name
|
||||
# # inference
|
||||
|
@ -247,23 +257,29 @@ def create_save_path(base_path: str) -> Dict[str, str]:
|
|||
def arg_parse()->argparse.Namespace:
|
||||
# # 增加参数一个是数据集的路径,另外一个是保存的路径
|
||||
parser = argparse.ArgumentParser(description='Process some integers.')
|
||||
parser.add_argument('--data_path', type=str, default='', help='path to dataset')
|
||||
parser.add_argument('--save_path', type=str, default='./train_result', help='path to save result')
|
||||
parser.add_argument('--dataset', default="", help='path to test dataset')
|
||||
parser.add_argument('--model_output', default="", help='path to model result path')
|
||||
# parser.add_argument('--data_path', type=str, default='', help='path to dataset')
|
||||
# parser.add_argument('--save_path', type=str, default='./train_result', help='path to save result')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = arg_parse()
|
||||
#创建目录
|
||||
save_path = create_save_path(args.save_path)
|
||||
save_path = create_save_path(args.model_output)
|
||||
print("save path: ", save_path)
|
||||
|
||||
original_path = save_path['train_original']
|
||||
train_pre_process_path = save_path['train_pre_process']
|
||||
model_save_path = save_path['model_save_path']
|
||||
|
||||
#从dataset目录获取数据集名称
|
||||
for file in os.listdir(args.dataset):
|
||||
if file.endswith('.zip'):
|
||||
zip_file_path = os.path.join(args.dataset, file)
|
||||
# 解压数据集到original_path
|
||||
with zipfile.ZipFile(args.data_path, 'r') as zip_ref:
|
||||
with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(original_path)
|
||||
|
||||
#image_path = './train_and_test'
|
||||
|
@ -273,10 +289,10 @@ if __name__ == '__main__':
|
|||
model_path = train(train_pre_process_path, model_save_path)
|
||||
print("test start")
|
||||
|
||||
print("test start")
|
||||
test_path= os.path.join(original_path, 'test')
|
||||
predict_and_plot(model_path, test_path, args.save_path, True)
|
||||
print("test end")
|
||||
# print("test start")
|
||||
# test_path= os.path.join(original_path, 'test')
|
||||
# predict_and_plot(model_path, test_path, args.save_path, True)
|
||||
# print("test end")
|
||||
|
||||
|
||||
# if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue