118 lines
5.3 KiB
Python
118 lines
5.3 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from egnn_core.data import load_data_v2
|
|
import os
|
|
import json
|
|
|
|
#根据连接来遍历点的表
|
|
#第一步,遍历edge index.T里面的第一列对应的点的序号[:,0]
|
|
#第二步,根据序号去找对应的label是否为普通,即0,如果等于0进入下一步,不然跳过;这一步是为了排除掉非普通原子的
|
|
#第三步,判断普通的原子的邻居是不是也是普通的:
|
|
#先找edge index.T里面的第一列对应的点的序号[:,1],即邻居有哪些
|
|
#去判断这些邻居有多少个属于线缺陷,即label是2,并计数
|
|
#记住,此时要把该点从表中永远排除,避免重复技术
|
|
# edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]])
|
|
# labels = np.array([0,0,0,2,0,2,0,2,0,0,2])
|
|
|
|
|
|
# 指定文件夹路径
|
|
folder_path = '/cc/data/end-to-end-result-sv/gnn_data/after_test'
|
|
|
|
def count_num(edge_index, num_index_list):
|
|
unique_list = set() # 使用set来存储唯一的matching_index
|
|
for num_index in num_index_list:
|
|
matching_indices = [index for index, (first, _) in enumerate(edge_index) if first == num_index]
|
|
for matching_index in matching_indices:
|
|
if matching_index not in unique_list: # 检查是否已存在
|
|
unique_list.add(matching_index) # 添加到set中
|
|
unique_list = list(unique_list)
|
|
|
|
# 如果你需要将结果转换回列表,可以使用以下代码
|
|
return len(unique_list)
|
|
def process_json_file(json_file):
|
|
points, edge_index, labels, _ = load_data_v2(json_file) # 假设这是你的数据加载函数
|
|
edge_index = edge_index.T
|
|
|
|
num_4 = 0
|
|
four_list = []
|
|
have_recognized_list = []
|
|
have_van_four_list = []
|
|
|
|
for idx, (s, e) in enumerate(edge_index):
|
|
count = 0
|
|
if idx not in have_recognized_list and labels[s] == 0:
|
|
matching_indices = [index for index, (first, _) in enumerate(edge_index) if first == s]
|
|
for matching_index in matching_indices:
|
|
have_recognized_list.append(matching_index)
|
|
current_adj = edge_index[matching_index][1]
|
|
if labels[current_adj] == 2:
|
|
count += 1
|
|
if count == 4:
|
|
four_list.append(s)
|
|
for matching_index in matching_indices:
|
|
current_adj = edge_index[matching_index][1]
|
|
if labels[current_adj] == 2:
|
|
have_van_four_list.append(current_adj)
|
|
num_4 += 1
|
|
|
|
|
|
num_3 = 0
|
|
three_list = []
|
|
have_recognized_list = []
|
|
have_van_three_list = []
|
|
|
|
for idx, (s, e) in enumerate(edge_index):
|
|
count = 0
|
|
if idx not in have_recognized_list and idx not in four_list and labels[s] == 0:
|
|
matching_indices = [index for index, (first, _) in enumerate(edge_index) if first == s]
|
|
for matching_index in matching_indices:
|
|
have_recognized_list.append(matching_index)
|
|
current_adj = edge_index[matching_index][1]
|
|
if current_adj not in have_van_four_list and labels[current_adj] == 2:
|
|
count += 1
|
|
if count == 3:
|
|
three_list.append(s)
|
|
for matching_index in matching_indices:
|
|
current_adj = edge_index[matching_index][1]
|
|
if current_adj not in have_van_four_list and labels[current_adj] == 2:
|
|
have_van_three_list.append(current_adj)
|
|
num_3 += 1
|
|
|
|
|
|
num_2 = 0
|
|
two_list = []
|
|
have_recognized_list = []
|
|
have_van_two_list = []
|
|
|
|
for idx, (s, e) in enumerate(edge_index):
|
|
count = 0
|
|
if idx not in have_recognized_list and idx not in four_list and idx not in three_list and labels[s] == 0:
|
|
matching_indices = [index for index, (first, _) in enumerate(edge_index) if first == s]
|
|
for matching_index in matching_indices:
|
|
have_recognized_list.append(matching_index)
|
|
current_adj = edge_index[matching_index][1]
|
|
if current_adj not in have_van_four_list and current_adj not in have_van_three_list and labels[current_adj] == 2:
|
|
count += 1
|
|
if count == 2:
|
|
two_list.append(s)
|
|
for matching_index in matching_indices:
|
|
current_adj = edge_index[matching_index][1]
|
|
if current_adj not in have_van_four_list and current_adj not in have_van_three_list and labels[current_adj] == 2:
|
|
have_van_two_list.append(current_adj)
|
|
num_2 += 1
|
|
|
|
|
|
num_van_type_2 = len(set(have_van_two_list))
|
|
num_van_type_3 = len(set(have_van_three_list))
|
|
num_van_type_4 = len(set(have_van_four_list))
|
|
|
|
return num_2, num_3, num_4,num_van_type_2, num_van_type_3, num_van_type_4
|
|
|
|
# 遍历文件夹中的所有文件
|
|
for filename in os.listdir(folder_path):
|
|
if filename.endswith('.json'):
|
|
print(filename)
|
|
json_file = os.path.join(folder_path, filename)
|
|
num_2, num_3, num_4,num_van_type_2, num_van_type_3, num_van_type_4 = process_json_file(json_file)
|
|
print(f"邻居中分别有2个3个4个线缺陷原子的正常原子数量: {num_2}, {num_3}, {num_4}")
|
|
print(f"对邻居中分别有2个3个4个线缺陷原子的正常原子,进行该类型总的邻居缺陷原子的个数的统计: {num_van_type_2}, {num_van_type_3}, {num_van_type_4}") |