atom-predict/egnn_v2/stastic_center_v2.py

113 lines
4.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import matplotlib.pyplot as plt
from egnn_core.data import load_data_v3,load_data_v4
import os
import json
#根据连接来遍历点的表
#第一步,遍历edge index.T里面的第一列对应的点的序号[:,0]
#第二步,根据序号去找对应的label是否为普通,即0,如果等于0进入下一步,不然跳过;这一步是为了排除掉非普通原子的
#第三步,判断普通的原子的邻居是不是也是普通的:
#先找edge index.T里面的第一列对应的点的序号[:,1],即邻居有哪些
#去判断这些邻居有多少个属于线缺陷,即label是2,并计数
#记住,此时要把该点从表中永远排除,避免重复技术
# edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]])
# labels = np.array([0,0,0,2,0,2,0,2,0,0,2])
# 指定文件夹路径
folder_path = '/home/gao/mouclear/analyze_center/2'
img_size = 2048
def find_adj(edge_index,adj_indexs):
matching_indices = []
for matching_index in adj_indexs:
matching_indices.append([index for index, (first, _) in enumerate(edge_index) if (first == matching_index)])
matching_indices = [item for sublist in matching_indices if sublist for item in sublist]
adj_indexs = []
for matching_index in matching_indices:
adj_indexs.append(edge_index[matching_index][1])
return adj_indexs, matching_indices
def process_json_file(json_file):
points, edge_index, labels, lights = load_data_v3(json_file) # 假设这是你的数据加载函数
edge_index = edge_index.T
# 找到数组中等于1的元素的索引
center_index = np.where(labels == 3)
center_point = points[center_index]
# center_s = edge_index[center_index][:, 0]
adj_1_indexs = []
matching_indices_1 = [index for index, (first, _) in enumerate(edge_index) if (first == center_index )]
for matching_index in matching_indices_1:
adj_1_indexs.append(edge_index[matching_index][1])
adj_2_indexs,matching_indices_2 = find_adj(edge_index,adj_1_indexs)
adj_3_indexs,matching_indices_3 = find_adj(edge_index,adj_2_indexs)
adj_4_indexs,matching_indices_4 = find_adj(edge_index,adj_3_indexs)
combined_index = list(set(adj_1_indexs + adj_2_indexs + adj_3_indexs + adj_4_indexs))
matching_indices = list(set(matching_indices_1 + matching_indices_2 + matching_indices_3 + matching_indices_4))
print(len(combined_index))
print(len(matching_indices))
selected_point = []
selected_light = []
for i in combined_index:
selected_point.append(points[i])
selected_light.append(lights[i])
selected_point = np.array(selected_point)
min_x = np.min(selected_point[:, 1])
max_x = np.max(selected_point[:, 1])
min_y = np.min(selected_point[:, 0])
max_y = np.max(selected_point[:, 0])
# 计算边界框的宽度和高度
width = max_x - min_x
height = max_y - min_y
# 为边界框添加一些边界例如原始宽度和高度的10%
padding = 0.1
border_x = int(width * padding)
border_y = int(height * padding)
# 调整背景图大小
new_width = int(width + 2 * border_x)
new_height = int(height + 2 * border_y)
bg = np.zeros((new_height, new_width))
# 调整点的坐标
selected_point[:, 1] -= min_x - border_x
selected_point[:, 0] -= min_y - border_y
plt.figure(figsize=(new_width / 100, new_height / 100)) # 单位为英寸100是默认的dpi
plt.imshow(bg, cmap='gray')
for point in selected_point:
# print(points[i])
# if labels[i] == 0:
plt.scatter(point[1], point[0], s=18, c='white', zorder=2)
# elif labels[i] == 1 or 2:
# plt.scatter(points[i][1], points[i][0], s=18, c='yellow', zorder=2)
plt.axis('off')
plt.tight_layout()
# plt.tight_layout()
plt.savefig('2_.png')
plt.show()
# 遍历文件夹中的所有文件
for filename in os.listdir(folder_path):
if filename.endswith('.json'):
print(filename)
json_file = os.path.join(folder_path, filename)
process_json_file(json_file)