98 lines
3.6 KiB
Python
98 lines
3.6 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from egnn_core.data import load_data_v3,load_data_v4
|
|
import os
|
|
import json
|
|
|
|
#根据连接来遍历点的表
|
|
#第一步,遍历edge_index.T里面的第一列对应的点的序号[:,0]
|
|
#第二步,根据序号去找对应的label是否为普通,即0,如果等于0进入下一步,不然跳过;这一步是为了排除掉非普通原子的
|
|
#第三步,判断普通的原子的邻居是不是也是普通的:
|
|
#先找edge_index.T里面的第二列对应的点的序号[:,1],即邻居有哪些
|
|
#去判断这些邻居有多少个属于线缺陷,即label是2,并计数
|
|
#记住,此时要把该点从表中永远排除,避免重复计数
|
|
# edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]])
|
|
# labels = np.array([0,0,0,2,0,2,0,2,0,0,2])
|
|
|
|
|
|
# Directory scanned for the *.json annotation files to visualize.
folder_path: str = '/home/gao/mouclear/analyze_center/2'

# Width and height of the square black background image the points are drawn on.
img_size: int = 2048
|
|
|
|
def find_adj(edge_index, adj_indexs):
    """Expand a frontier of nodes by one hop along directed edges.

    Args:
        edge_index: Sequence of ``(source, target)`` pairs (for example an
            ``(E, 2)`` array); row ``k`` describes edge ``k``.
        adj_indexs: Iterable of node ids forming the current frontier.
            Duplicates are allowed and are repeated in the output.

    Returns:
        ``(adj_indexs, matching_indices)``: ``matching_indices`` lists the row
        positions of every edge whose source is in the frontier, grouped by
        frontier order and then by edge order; ``adj_indexs`` holds the target
        node of each of those edges, in the same order.
    """
    # Group edge positions by source node in a single pass so each frontier
    # lookup is O(1) instead of rescanning all edges per frontier node.
    # Node ids are assumed to be integral (int() normalises numpy scalars).
    edges_by_source = {}
    for position, (first, _) in enumerate(edge_index):
        edges_by_source.setdefault(int(first), []).append(position)

    matching_indices = []
    for node in adj_indexs:
        matching_indices.extend(edges_by_source.get(int(node), []))

    adj_indexs = [edge_index[position][1] for position in matching_indices]
    return adj_indexs, matching_indices
|
|
|
|
|
|
def _collect_neighborhood(edge_index, start_nodes, hops=4):
    """Return the nodes and edges reached within ``hops`` steps of ``start_nodes``.

    Follows directed ``(source, target)`` rows of ``edge_index`` outward from
    ``start_nodes``. Each node is expanded at most once: this yields the same
    node/edge sets as re-expanding every target on every hop, without the
    multiplicative growth of duplicate frontier entries.

    Returns:
        ``(nodes, edges)``: sorted node ids that are the target of an edge
        reached in 1..``hops`` steps, and sorted row positions in
        ``edge_index`` of every edge traversed.
    """
    reached_nodes = set()
    used_edges = set()
    expanded = set()
    frontier = [int(node) for node in start_nodes]
    for _ in range(hops):
        # Deduplicate (order-preserving) and skip nodes already expanded.
        frontier = [node for node in dict.fromkeys(frontier) if node not in expanded]
        if not frontier:
            break
        expanded.update(frontier)
        targets, edge_positions = find_adj(edge_index, frontier)
        used_edges.update(edge_positions)
        frontier = [int(target) for target in targets]
        reached_nodes.update(frontier)
    return sorted(reached_nodes), sorted(used_edges)


def _plot_neighborhood(points, labels, edge_index, nodes, edges, center_point, save_path):
    """Draw the neighbourhood on a black canvas, save it to ``save_path`` and show it.

    Points are plotted with column 1 on x and column 0 on y. Label 0 is drawn
    white, labels 1 and 2 yellow, centre points red; edges are drawn in grey
    underneath the points.
    """
    fig = plt.figure(figsize=(9, 9))
    plt.imshow(np.zeros((img_size, img_size)), cmap='gray')

    for node in nodes:
        print(points[node])
        if labels[node] == 0:
            color = 'white'
        elif labels[node] in (1, 2):
            color = 'yellow'
        else:
            # Centre points (label 3) are drawn in red below; other labels are skipped.
            continue
        plt.scatter(points[node][1], points[node][0], s=18, c=color, zorder=2)

    plt.scatter(center_point[:, 1], center_point[:, 0], s=18, c='red', zorder=2)

    for edge in edges:
        start = points[edge_index[edge][0]]
        end = points[edge_index[edge][1]]
        plt.plot([start[1], end[1]], [start[0], end[0]], linewidth=1, c='#C0C0C0', zorder=1)

    plt.axis('off')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()
    # Release the figure; otherwise figures accumulate when called in a loop.
    plt.close(fig)


def process_json_file(json_file, save_path='2.png', hops=4):
    """Plot the ``hops``-step graph neighbourhood around the centre points of one file.

    Args:
        json_file: Path handed to ``load_data_v3``.
        save_path: Output image path. Defaults to ``'2.png'``, so repeated
            calls overwrite the same file unless a distinct path is given.
        hops: Number of edge steps to expand outward from the centre points.
    """
    points, edge_index, labels, _ = load_data_v3(json_file)
    # Transposed so that each row is one (source, target) edge
    # (assumes load_data_v3 returns edges as a (2, E) array — TODO confirm).
    edge_index = edge_index.T

    # Centre points are those labelled 3. Using the index array directly
    # (instead of comparing a scalar to the np.where tuple) handles zero,
    # one or many centres.
    center_index = np.where(labels == 3)[0]
    center_point = points[center_index]

    nodes, edges = _collect_neighborhood(edge_index, center_index, hops)
    print(len(nodes))
    print(len(edges))

    _plot_neighborhood(points, labels, edge_index, nodes, edges, center_point, save_path)
|
|
|
|
# Visualize every JSON file found in folder_path, in os.listdir order.
json_names = (name for name in os.listdir(folder_path) if name.endswith('.json'))
for json_name in json_names:
    print(json_name)
    process_json_file(os.path.join(folder_path, json_name))
|