# atom-predict/egnn_v2/stastic_V2.py

# 118 lines
# 5.3 KiB
# Python

import numpy as np
import matplotlib.pyplot as plt
from egnn_core.data import load_data_v2
import os
import json
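# Walk the node table through the graph's connections (edges):
# Step 1: iterate over the source node indices, i.e. the first column of edge_index.T ([:, 0]).
# Step 2: look up that node's label and continue only if it is 0 (a normal atom); otherwise skip it.
#         This step excludes all non-normal atoms.
# Step 3: examine the neighbours of the normal atom:
#         find its neighbours in the second column of edge_index.T ([:, 1]),
#         then count how many of them are line-defect atoms (label 2).
# Remember: once an atom has been handled it must be excluded permanently, to avoid double counting.
# The counting runs in three passes (exactly 4, then 3, then 2 defect neighbours); defect atoms
# already attributed in an earlier pass are ignored in the later passes.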
# Toy example of the layout used after the transpose: each row of edge_index is a
# (source, target) pair and labels[i] is the label of node i.
#   edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]])
#   labels = np.array([0, 0, 0, 2, 0, 2, 0, 2, 0, 0, 2])
# Folder whose *.json graph files are analysed by this script.
folder_path = "/cc/data/end-to-end-result-sv/gnn_data/after_test"
def count_num(edge_index, num_index_list):
    """Count the edges whose source node is one of ``num_index_list``.

    Args:
        edge_index: Iterable of ``(source, target)`` pairs, e.g. the
            transposed edge array with one edge per row.
        num_index_list: Node indices whose outgoing edges should be counted.

    Returns:
        The number of distinct edges (by position in ``edge_index``) whose
        source is in ``num_index_list``. Listing a node more than once does
        not count its edges more than once.
    """
    # Build the lookup once: O(1) membership per edge instead of rescanning
    # every edge for every requested node.
    targets = set(num_index_list)
    return sum(1 for source, _ in edge_index if source in targets)
def _neighbours_by_source(edge_index):
    """Group edge targets by source node, keeping first-seen source order.

    ``edge_index`` is in (2, E) layout: row 0 holds sources, row 1 targets.
    Duplicate edges are kept, so a repeated neighbour counts once per edge.
    """
    neighbours = {}
    for source, target in np.asarray(edge_index).T.tolist():
        neighbours.setdefault(source, []).append(target)
    return neighbours


def _match_tier(neighbours, labels, wanted, claimed, skip):
    """Select normal atoms with exactly ``wanted`` unclaimed line-defect neighbours.

    Args:
        neighbours: Mapping source atom -> list of neighbour atoms.
        labels: Per-atom labels indexable by atom id (0 = normal, 2 = line defect).
        wanted: Exact number of qualifying neighbours an atom must have.
        claimed: Defect atoms already attributed to a higher tier; ignored here.
        skip: Atoms already assigned to a higher tier.

    Returns:
        ``(atoms, defects)``: the matching atoms, and the qualifying defect
        neighbours of every match concatenated (may contain repeats).
    """
    atoms, defects = [], []
    for atom, adjacent in neighbours.items():
        if atom in skip or labels[atom] != 0:
            continue
        hits = [n for n in adjacent if n not in claimed and labels[n] == 2]
        if len(hits) == wanted:
            atoms.append(atom)
            defects.extend(hits)
    return atoms, defects


def process_json_file(json_file, *, loader=None):
    """Count normal atoms by how many line-defect neighbours they have.

    Normal atoms (label 0) are classified in three tiers by the number of
    their neighbours labelled 2 (line defect): exactly 4, then 3, then 2.
    A defect atom attributed to a higher tier is not counted again in a
    lower tier, and an atom classified in a higher tier is not re-examined.

    Args:
        json_file: Path handed to the loader.
        loader: Optional callable with the same contract as ``load_data_v2``.
            It must return a 4-tuple whose 2nd item is the edge index in
            (2, E) layout and whose 3rd item is the per-atom labels.
            Defaults to ``load_data_v2``.

    Returns:
        ``(num_2, num_3, num_4, num_van_type_2, num_van_type_3, num_van_type_4)``:
        the number of normal atoms in each tier, followed by the number of
        distinct defect atoms attributed to each tier.
    """
    load = load_data_v2 if loader is None else loader
    _, edge_index, labels, _ = load(json_file)
    # One pass over the edges instead of rescanning all of them per atom.
    neighbours = _neighbours_by_source(edge_index)

    four_atoms, four_defects = _match_tier(neighbours, labels, 4, set(), set())

    # Skip already-classified atoms by atom id. The previous version compared
    # edge positions against atom ids, which could drop unrelated atoms.
    claimed = set(four_defects)
    done = set(four_atoms)
    three_atoms, three_defects = _match_tier(neighbours, labels, 3, claimed, done)

    claimed |= set(three_defects)
    done |= set(three_atoms)
    two_atoms, two_defects = _match_tier(neighbours, labels, 2, claimed, done)

    return (
        len(two_atoms),
        len(three_atoms),
        len(four_atoms),
        len(set(two_defects)),
        len(set(three_defects)),
        len(set(four_defects)),
    )
# Analyse every .json file in the folder and print its statistics.
for filename in os.listdir(folder_path):
    if not filename.endswith('.json'):
        continue  # only JSON graph files are analysed
    print(filename)
    json_file = os.path.join(folder_path, filename)
    (num_2, num_3, num_4,
     num_van_type_2, num_van_type_3, num_van_type_4) = process_json_file(json_file)
    # Number of normal atoms having 2 / 3 / 4 line-defect neighbours.
    print(f"邻居中分别有2个3个4个线缺陷原子的正常原子数量: {num_2}, {num_3}, {num_4}")
    # Number of distinct line-defect atoms attributed to each of those groups.
    print(f"对邻居中分别有2个3个4个线缺陷原子的正常原子,进行该类型总的邻居缺陷原子的个数的统计: {num_van_type_2}, {num_van_type_3}, {num_van_type_4}")