53 lines
2.1 KiB
Python
53 lines
2.1 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from egnn_core.data import load_data_v2
|
|
import os
|
|
import json
|
|
|
|
#根据连接来遍历点的表
|
|
#第一步,遍历edge index.T里面的第一列对应的点的序号[:,0]
|
|
#第二步,根据序号去找对应的label是否为普通,即0,如果等于0进入下一步,不然跳过;这一步是为了排除掉非普通原子的
|
|
#第三步,判断普通的原子的邻居是不是也是普通的:
|
|
#先找edge index.T里面的第一列对应的点的序号[:,1],即邻居有哪些
|
|
#去判断这些邻居有多少个属于线缺陷,即label是2,并计数
|
|
#记住,此时要把该点从表中永远排除,避免重复技术
|
|
# edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]])
|
|
# labels = np.array([0,0,0,2,0,2,0,2,0,0,2])
|
|
|
|
|
|
# 指定文件夹路径
|
|
folder_path = '/cc/data/end-to-end-result-sv/gnn_data/after_test'
|
|
|
|
def process_json_file(json_file):
|
|
points, edge_index, labels, _ = load_data_v2(json_file) # 假设这是你的数据加载函数
|
|
edge_index = edge_index.T
|
|
num_2 = 0
|
|
num_3 = 0
|
|
num_4 = 0
|
|
|
|
have_recognized_list = []
|
|
for idx, (s, e) in enumerate(edge_index):
|
|
count = 0
|
|
if idx not in have_recognized_list and labels[s] == 0:
|
|
matching_indices = [index for index, (first, _) in enumerate(edge_index) if first == s]
|
|
for matching_index in matching_indices:
|
|
have_recognized_list.append(matching_index)
|
|
current_adj = edge_index[matching_index][1]
|
|
if labels[current_adj] == 2:
|
|
count += 1
|
|
if count == 2:
|
|
num_2 += 1
|
|
elif count == 3:
|
|
num_3 += 1
|
|
elif count == 4:
|
|
num_4 += 1
|
|
|
|
return num_2, num_3, num_4
|
|
|
|
# 遍历文件夹中的所有文件
|
|
for filename in os.listdir(folder_path):
|
|
if filename.endswith('.json'):
|
|
print(filename)
|
|
json_file = os.path.join(folder_path, filename)
|
|
num_2, num_3, num_4 = process_json_file(json_file)
|
|
print(f"正常原子为中心,邻居中分别有2个3个4个线缺陷原子的数量: {num_2}, {num_3}, {num_4}") |