192 lines
7.3 KiB
Python
192 lines
7.3 KiB
Python
import math
|
||
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from egnn_core.data import load_data_v2
|
||
import os
|
||
import json
|
||
|
||
#根据连接来遍历点的表
|
||
#第一步,遍历edge index.T里面的第一列对应的点的序号[:,0]
|
||
#第二步,根据序号去找对应的label是否为普通,即0,如果等于0进入下一步,不然跳过;这一步是为了排除掉非普通原子的
|
||
#第三步,判断普通的原子的邻居是不是也是普通的:
|
||
#先找edge index.T里面的第二列对应的点的序号[:,1],即邻居有哪些
|
||
#去判断这些邻居有多少个属于线缺陷,即label是2,并计数
|
||
#记住,此时要把该点从表中永远排除,避免重复计数
|
||
# edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]])
|
||
# labels = np.array([0,0,0,2,0,2,0,2,0,0,2])
|
||
|
||
|
||
# Folder containing the JSON annotation files to process.
# Can be overridden with the LINE_DEFECT_DIR environment variable so the
# script runs on other machines; defaults to the original location.
folder_path = os.environ.get('LINE_DEFECT_DIR', '/home/gao/mouclear/cc/final_todo/line')
|
||
|
||
def count_num(edge_index, num_index_list):
    """Count the edges whose source node is one of ``num_index_list``.

    Args:
        edge_index: iterable of ``(source, target)`` pairs, e.g. a list of
            tuples or an array with one row per edge.
        num_index_list: node indices to match against each edge's source.

    Returns:
        Number of distinct edges (by position in ``edge_index``) whose source
        appears in ``num_index_list``. Duplicate entries in
        ``num_index_list`` do not inflate the count.
    """
    # Build the lookup set once so each edge is tested in O(1), instead of
    # rescanning every edge for every requested index (O(E * N)).
    wanted = set(num_index_list)
    return sum(1 for first, _ in edge_index if first in wanted)
|
||
|
||
def _count_defect_second_neighbors(points, edges, labels, radius=52.5):
    """Group defect atoms by how many distinct defect atoms lie two hops away.

    Every source node ``s`` of ``edges`` is visited once, in order of first
    appearance. If ``labels[s]`` is 1 or 2, walk ``s -> neighbour ->
    neighbour-of-neighbour`` along the directed edges and count the distinct
    (by coordinates) second-hop nodes that are labelled 1 or 2, are not ``s``
    itself, and lie within ``radius`` of ``s`` (Euclidean distance on the
    first two coordinates).

    Args:
        points: per-node coordinates indexable by node index.
        edges: iterable of ``(source, target)`` node-index pairs.
        labels: per-node labels indexable by node index.
        radius: maximum distance for a second-hop atom to be counted.

    Returns:
        dict mapping each observed count to the list of ``points[s]`` rows
        that produced it.
    """
    # Build the adjacency once; rescanning the whole edge list for every node
    # and every neighbour is O(V * E).
    adjacency = {}
    for src, dst in edges:
        adjacency.setdefault(int(src), []).append(int(dst))

    def is_defect(node):
        return labels[node] == 1 or labels[node] == 2

    buckets = {}
    # Dict iteration order == order in which sources first appear in edges.
    for s in adjacency:
        if not is_defect(s):
            continue
        origin = points[s]
        seen = set()  # coordinates of second-hop atoms already considered
        count = 0
        for mid in adjacency[s]:
            for far in adjacency.get(mid, ()):
                if far == s or not is_defect(far):
                    continue
                far_point = points[far]
                key = tuple(far_point)
                if key in seen:
                    continue
                seen.add(key)
                dist = math.sqrt((far_point[0] - origin[0]) ** 2
                                 + (far_point[1] - origin[1]) ** 2)
                if dist <= radius:
                    count += 1
        buckets.setdefault(count, []).append(origin)
    return buckets


def _plot_count_buckets(buckets, output_path):
    """Scatter-plot each bucket of atom coordinates and save it as an image.

    Each point is drawn at ``x=point[1]``, ``y=point[0]``. Buckets 0..7 use a
    fixed palette; any larger count reuses the darkest colour so those atoms
    are no longer silently omitted. The figure is always closed so repeated
    calls do not accumulate open figures.
    """
    colors = ['#f0514f', '#f39156', '#fee358', '#b5fa9f',
              '#50ecf8', '#5277ef', '#4e4aaa', '#4e4aaa']
    fig = plt.figure(figsize=(9, 9))
    try:
        for count in sorted(buckets):
            coords = np.asarray(buckets[count])
            color = colors[min(count, len(colors) - 1)]
            plt.scatter(coords[:, 1], coords[:, 0], s=23, c=color)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(output_path, dpi=300)
    finally:
        plt.close(fig)


def process_json_file(json_file):
    """Count, plot and save second-neighbour defect statistics for one file.

    Loads the graph with ``load_data_v2`` and prints the number of labels and
    how many equal 0, 1 and 2. Each defect atom (label 1 or 2) is then
    bucketed by the number of distinct defect atoms two hops away within
    52.5 units. The buckets are scatter-plotted and the figure is saved next
    to ``json_file`` with a ``.png`` extension.

    Returns:
        11-tuple ``(num_0, ..., num_10)``: how many defect atoms have exactly
        0..10 such second-hop defect neighbours.
    """
    points, edge_index, labels, _ = load_data_v2(json_file)

    print(len(labels))
    print(np.sum(labels == 0))
    print(np.sum(labels == 1))
    print(np.sum(labels == 2))

    # Transposed so that each row is one (source, target) edge; presumably
    # the loader returns shape (2, E) — TODO confirm.
    buckets = _count_defect_second_neighbors(points, edge_index.T, labels)

    # splitext swaps only the final extension; str.replace would also rewrite
    # any '.json' occurring earlier in the path.
    output_path = os.path.splitext(json_file)[0] + '.png'
    _plot_count_buckets(buckets, output_path)

    # Counts 8..10 are now actually tallied (previously always 0).
    return tuple(len(buckets.get(k, ())) for k in range(11))
|
||
|
||
# Process every JSON annotation file in the folder and report the per-file
# second-neighbour defect counts. Guarded so that importing this module (e.g.
# to reuse process_json_file) does not scan a machine-specific directory.
if __name__ == '__main__':
    # Sorted so the per-file output order is reproducible.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.endswith('.json'):
            continue
        print(filename)
        json_file = os.path.join(folder_path, filename)
        num_0, num_1, num_2, num_3, num_4, num_5, num_6, num_7, num_8, num_9, num_10 = process_json_file(json_file)
        print(f"邻居的邻居中分别有0个1个2个3个4个5个6个7个8个9个10个缺陷原子的缺陷原子数量: {num_0}, {num_1}, {num_2},{num_3}, {num_4}, {num_5},{num_6}, {num_7}, {num_8}, {num_9}, {num_10}")