atom-predict/egnn_v2/stastic_V3.py

39 lines
1.6 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
from egnn_core.data import load_data_v2
import os
import json
#根据连接来遍历点的表
#第一步,遍历edge index.T里面的第一列对应的点的序号[:,0]
#第二步,根据序号去找对应的label是否为普通,即0,如果等于0进入下一步,不然跳过;这一步是为了排除掉非普通原子的
#第三步,判断普通的原子的邻居是不是也是普通的:
#先找edge index.T里面的第一列对应的点的序号[:,1],即邻居有哪些
#去判断这些邻居有多少个属于线缺陷,即label是2,并计数
#记住,此时要把该点从表中永远排除,避免重复技术
# edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]])
# labels = np.array([0,0,0,2,0,2,0,2,0,0,2])
# 指定文件夹路径
folder_path = '/cc/data/end-to-end-result-sv/gnn_data/after_test'
def process_json_file(json_file):
points, edge_index, labels, _ = load_data_v2(json_file) # 假设这是你的数据加载函数
label_array = np.array(labels)
# 计算等于1、2、3的数的数量
count_0 = len(label_array)
count_1 = np.sum(label_array == 0)
count_2 = np.sum(label_array == 1)
count_3 = np.sum(label_array == 2)
return count_0,count_1, count_2, count_3
# 遍历文件夹中的所有文件
for filename in os.listdir(folder_path):
if filename.endswith('.json'):
print(filename)
json_file = os.path.join(folder_path, filename)
num_1, num_2, num_3, num_4 = process_json_file(json_file)
print(f"总数\正常\\线的数量: {num_1},{num_2}, {num_3}, {num_4}")