113 lines
4.4 KiB
Python
113 lines
4.4 KiB
Python
import math
|
||
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from egnn_core.data import load_data_v2
|
||
import os
|
||
import json
|
||
|
||
#根据连接来遍历点的表
|
||
#第一步,遍历edge index.T里面的第一列对应的点的序号[:,0]
|
||
#第二步,根据序号去找对应的label是否为普通,即0,如果等于0进入下一步,不然跳过;这一步是为了排除掉非普通原子的
|
||
#第三步,判断普通的原子的邻居是不是也是普通的:
|
||
#先找edge index.T里面的第一列对应的点的序号[:,1],即邻居有哪些
|
||
#去判断这些邻居有多少个属于线缺陷,即label是2,并计数
|
||
#记住,此时要把该点从表中永远排除,避免重复技术
|
||
# edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]])
|
||
# labels = np.array([0,0,0,2,0,2,0,2,0,0,2])
|
||
|
||
|
||
# 指定文件夹路径
|
||
folder_path = '/home/gao/mouclear/cc/final_todo/SV/otherresult'
|
||
|
||
def count_num(edge_index, num_index_list):
|
||
unique_list = set() # 使用set来存储唯一的matching_index
|
||
for num_index in num_index_list:
|
||
matching_indices = [index for index, (first, _) in enumerate(edge_index) if first == num_index]
|
||
for matching_index in matching_indices:
|
||
if matching_index not in unique_list: # 检查是否已存在
|
||
unique_list.add(matching_index) # 添加到set中
|
||
unique_list = list(unique_list)
|
||
|
||
# 如果你需要将结果转换回列表,可以使用以下代码
|
||
return len(unique_list)
|
||
|
||
def process_json_file(json_file):
|
||
points, edge_index, labels, _ = load_data_v2(json_file) # 假设这是你的数据加载函数
|
||
edge_index = edge_index.T
|
||
num_10 = 0
|
||
num_9 = 0
|
||
num_8 = 0
|
||
num_7 = 0
|
||
num_6 = 0
|
||
num_5 = 0
|
||
num_4 = 0
|
||
num_3 = 0
|
||
num_2 = 0
|
||
num_1 = 0
|
||
num_0 = 0
|
||
have_recognized_list = []
|
||
|
||
a = (labels == 1)
|
||
b = (labels == 2)
|
||
print(np.sum(a), np.sum(b))
|
||
|
||
for idx, (s, e) in enumerate(edge_index):
|
||
if s not in have_recognized_list:
|
||
|
||
if labels[s] == 1 or labels[s] == 2:
|
||
count = 0
|
||
matching_indices = [index for index, (first, _) in enumerate(edge_index) if first == s]
|
||
have_recognized_adj_adj_list = []
|
||
for matching_index in matching_indices:
|
||
# have_recognized_list.append(matching_index)
|
||
#邻居编号
|
||
current_adj = edge_index[matching_index][1]
|
||
#邻居的邻居索引
|
||
adj_adj = [index for index, (first, _) in enumerate(edge_index) if first == current_adj]
|
||
for i in adj_adj:
|
||
if (labels[edge_index[i][1]] == 2 or labels[edge_index[i][1]] == 1) and i not in have_recognized_list and edge_index[i][1] != s:
|
||
adj_adj_point = points[edge_index[i][1]]
|
||
cur_point = points[s]
|
||
dist = math.sqrt((adj_adj_point[0] - cur_point[0])**2 + (adj_adj_point[1]- cur_point[1])**2)
|
||
if dist <= 40:
|
||
count += 1
|
||
have_recognized_adj_adj_list.append(i)
|
||
# count+=1
|
||
|
||
if count ==10:
|
||
num_10 += 1
|
||
if count ==9:
|
||
num_9 += 1
|
||
if count ==8:
|
||
num_8 += 1
|
||
if count ==7:
|
||
num_7 += 1
|
||
if count ==6:
|
||
num_6 += 1
|
||
if count ==5:
|
||
num_5 += 1
|
||
if count ==4:
|
||
num_4 += 1
|
||
if count ==3:
|
||
num_3 += 1
|
||
if count ==2:
|
||
num_2 += 1
|
||
if count ==1:
|
||
num_1 += 1
|
||
if count ==0:
|
||
num_0 += 1
|
||
|
||
have_recognized_list.append(s)
|
||
|
||
|
||
|
||
return num_0,num_1,num_2, num_3, num_4, num_5, num_6, num_7, num_8, num_9, num_10
|
||
|
||
# 遍历文件夹中的所有文件
|
||
for filename in os.listdir(folder_path):
|
||
if filename.endswith('.json'):
|
||
print(filename)
|
||
json_file = os.path.join(folder_path, filename)
|
||
num_0,num_1,num_2, num_3, num_4, num_5, num_6,num_7, num_8, num_9, num_10 = process_json_file(json_file)
|
||
print(f"邻居的邻居中分别有0个1个2个3个4个5个6个7个8个9个10个缺陷原子的缺陷原子数量: {num_0}, {num_1}, {num_2},{num_3}, {num_4}, {num_5},{num_6}, {num_7}, {num_8}, {num_9}, {num_10}") |