163 lines
5.5 KiB
Python
163 lines
5.5 KiB
Python
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from egnn_core.data import load_data_v3,load_data_v4
|
||
import os
|
||
import json
|
||
|
||
#根据连接来遍历点的表
|
||
#第一步,遍历edge index.T里面的第一列对应的点的序号[:,0]
|
||
#第二步,根据序号去找对应的label是否为普通,即0,如果等于0进入下一步,不然跳过;这一步是为了排除掉非普通原子的
|
||
#第三步,判断普通的原子的邻居是不是也是普通的:
|
||
#先找edge index.T里面的第一列对应的点的序号[:,1],即邻居有哪些
|
||
#去判断这些邻居有多少个属于线缺陷,即label是2,并计数
|
||
#记住,此时要把该点从表中永远排除,避免重复技术
|
||
# edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]])
|
||
# labels = np.array([0,0,0,2,0,2,0,2,0,0,2])
|
||
|
||
|
||
# 指定文件夹路径
|
||
folder_path = '/home/gao/mouclear/analyze_center/2'
|
||
img_size = 2048
|
||
|
||
|
||
def draw_light(img, point, light, ps=1):
|
||
img_h, img_w = img.shape
|
||
h, w = point
|
||
|
||
# 计算窗口的边界
|
||
hs = np.clip(h - ps, 0, img_h - 1)
|
||
ws = np.clip(w - ps, 0, img_w - 1)
|
||
he = np.clip(h + ps, 0, img_h - 1)
|
||
we = np.clip(w + ps, 0, img_w - 1)
|
||
|
||
# 从light中提取窗口区域的像素值
|
||
window = light.reshape((he - hs, we - ws))
|
||
|
||
# 绘制窗口到背景图像上
|
||
img[hs:he, ws:we] = window
|
||
|
||
return img
|
||
def find_adj(edge_index,adj_indexs):
|
||
matching_indices = []
|
||
for matching_index in adj_indexs:
|
||
matching_indices.append([index for index, (first, _) in enumerate(edge_index) if (first == matching_index)])
|
||
|
||
matching_indices = [item for sublist in matching_indices if sublist for item in sublist]
|
||
|
||
adj_indexs = []
|
||
for matching_index in matching_indices:
|
||
adj_indexs.append(edge_index[matching_index][1])
|
||
return adj_indexs, matching_indices
|
||
|
||
|
||
def process_json_file(json_file,img_file):
|
||
points, edge_index, labels, lights = load_data_v3(json_file) # 假设这是你的数据加载函数
|
||
edge_index = edge_index.T
|
||
|
||
# 找到数组中等于1的元素的索引
|
||
center_index = np.where(labels == 3)
|
||
center_point = points[center_index]
|
||
|
||
# center_s = edge_index[center_index][:, 0]
|
||
|
||
adj_1_indexs = []
|
||
matching_indices_1 = [index for index, (first, _) in enumerate(edge_index) if (first == center_index )]
|
||
for matching_index in matching_indices_1:
|
||
adj_1_indexs.append(edge_index[matching_index][1])
|
||
|
||
adj_2_indexs,matching_indices_2 = find_adj(edge_index,adj_1_indexs)
|
||
adj_3_indexs,matching_indices_3 = find_adj(edge_index,adj_2_indexs)
|
||
adj_4_indexs,matching_indices_4 = find_adj(edge_index,adj_3_indexs)
|
||
|
||
combined_index = list(set(adj_1_indexs + adj_2_indexs + adj_3_indexs + adj_4_indexs))
|
||
matching_indices = list(set(matching_indices_1 + matching_indices_2 + matching_indices_3 + matching_indices_4))
|
||
print(len(combined_index))
|
||
print(len(matching_indices))
|
||
|
||
selected_point = []
|
||
selected_light = []
|
||
for i in combined_index:
|
||
selected_point.append(points[i])
|
||
selected_light.append(lights[i])
|
||
|
||
selected_point = np.array(selected_point)
|
||
selected_light = np.array(selected_light)
|
||
|
||
bg = np.zeros((img_size, img_size))
|
||
plt.figure(figsize=(9, 9))
|
||
|
||
for light, point in zip(selected_light, selected_point):
|
||
img_h, img_w = bg.shape # 背景图像的高度和宽度
|
||
h, w = point
|
||
|
||
ps = 1 # 因为 light 是 3x3 的区域,所以 ps 应该是 1
|
||
hs = np.clip(h - ps, 0, img_h - ps - 1)
|
||
ws = np.clip(w - ps, 0, img_w - ps - 1)
|
||
he = hs + 2 * ps + 1 # 3x3 区域的结束行索引
|
||
we = ws + 2 * ps + 1 # 3x3 区域的结束列索引
|
||
|
||
# 正确地将 light 重塑为 3x3 的二维数组
|
||
light_2d = light.reshape((3, 3))
|
||
|
||
# 绘制到背景图像上
|
||
bg[hs:he, ws:we] = light_2d
|
||
|
||
# 标记提取区域的中心点
|
||
bg[h, w] = 1 # 假设背景是0,将中心点设置为1以标记
|
||
|
||
|
||
plt.imshow(bg, cmap='gray')
|
||
|
||
|
||
#
|
||
# min_x = np.min(selected_point[:, 1])
|
||
# max_x = np.max(selected_point[:, 1])
|
||
# min_y = np.min(selected_point[:, 0])
|
||
# max_y = np.max(selected_point[:, 0])
|
||
#
|
||
# # 计算边界框的宽度和高度
|
||
# width = max_x - min_x
|
||
# height = max_y - min_y
|
||
#
|
||
# # 为边界框添加一些边界(例如,原始宽度和高度的10%)
|
||
# padding = 0.1
|
||
# border_x = int(width * padding)
|
||
# border_y = int(height * padding)
|
||
#
|
||
# # 调整背景图大小
|
||
# new_width = int(width + 2 * border_x)
|
||
# new_height = int(height + 2 * border_y)
|
||
# bg = np.zeros((new_height, new_width))
|
||
#
|
||
# # 调整点的坐标
|
||
# selected_point[:, 1] -= min_x - border_x
|
||
# selected_point[:, 0] -= min_y - border_y
|
||
#
|
||
# plt.figure(figsize=(new_width / 100, new_height / 100)) # 单位为英寸,100是默认的dpi
|
||
# plt.imshow(bg, cmap='gray')
|
||
#
|
||
# for point in selected_point:
|
||
# # print(points[i])
|
||
# # if labels[i] == 0:
|
||
# plt.scatter(point[1], point[0], s=18, c='white', zorder=2)
|
||
# # elif labels[i] == 1 or 2:
|
||
# # plt.scatter(points[i][1], points[i][0], s=18, c='yellow', zorder=2)
|
||
|
||
plt.axis('off')
|
||
|
||
plt.tight_layout()
|
||
# plt.tight_layout()
|
||
plt.savefig('2_.png')
|
||
plt.show()
|
||
|
||
# 遍历文件夹中的所有文件
|
||
for filename in os.listdir(folder_path):
|
||
if filename.endswith('.json'):
|
||
print(filename)
|
||
json_file = os.path.join(folder_path, filename)
|
||
if filename.endswith('.jpg'):
|
||
print(filename)
|
||
img_file = os.path.join(folder_path, filename)
|
||
|
||
process_json_file(json_file,img_file)
|