import numpy as np import matplotlib.pyplot as plt from egnn_core.data import load_data_v2 import os import json #根据连接来遍历点的表 #第一步,遍历edge index.T里面的第一列对应的点的序号[:,0] #第二步,根据序号去找对应的label是否为普通,即0,如果等于0进入下一步,不然跳过;这一步是为了排除掉非普通原子的 #第三步,判断普通的原子的邻居是不是也是普通的: #先找edge index.T里面的第一列对应的点的序号[:,1],即邻居有哪些 #去判断这些邻居有多少个属于线缺陷,即label是2,并计数 #记住,此时要把该点从表中永远排除,避免重复技术 # edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]]) # labels = np.array([0,0,0,2,0,2,0,2,0,0,2]) # 指定文件夹路径 folder_path = '/cc/data/end-to-end-result-sv/gnn_data/after_test' def count_num(edge_index, num_index_list): unique_list = set() # 使用set来存储唯一的matching_index for num_index in num_index_list: matching_indices = [index for index, (first, _) in enumerate(edge_index) if first == num_index] for matching_index in matching_indices: if matching_index not in unique_list: # 检查是否已存在 unique_list.add(matching_index) # 添加到set中 unique_list = list(unique_list) # 如果你需要将结果转换回列表,可以使用以下代码 return len(unique_list) def process_json_file(json_file): points, edge_index, labels, _ = load_data_v2(json_file) # 假设这是你的数据加载函数 edge_index = edge_index.T num_4 = 0 four_list = [] have_recognized_list = [] have_van_four_list = [] for idx, (s, e) in enumerate(edge_index): count = 0 if idx not in have_recognized_list and labels[s] == 0: matching_indices = [index for index, (first, _) in enumerate(edge_index) if first == s] for matching_index in matching_indices: have_recognized_list.append(matching_index) current_adj = edge_index[matching_index][1] if labels[current_adj] == 2: count += 1 if count == 4: four_list.append(s) for matching_index in matching_indices: current_adj = edge_index[matching_index][1] if labels[current_adj] == 2: have_van_four_list.append(current_adj) num_4 += 1 num_3 = 0 three_list = [] have_recognized_list = [] have_van_three_list = [] for idx, (s, e) in enumerate(edge_index): count = 0 if idx not in have_recognized_list and idx not in four_list and labels[s] == 0: matching_indices = [index for index, (first, _) in enumerate(edge_index) if first == s] for matching_index in matching_indices: have_recognized_list.append(matching_index) current_adj = edge_index[matching_index][1] if current_adj not in have_van_four_list and labels[current_adj] == 2: count += 1 if count == 3: three_list.append(s) for matching_index in matching_indices: current_adj = edge_index[matching_index][1] if current_adj not in have_van_four_list and labels[current_adj] == 2: have_van_three_list.append(current_adj) num_3 += 1 num_2 = 0 two_list = [] have_recognized_list = [] have_van_two_list = [] for idx, (s, e) in enumerate(edge_index): count = 0 if idx not in have_recognized_list and idx not in four_list and idx not in three_list and labels[s] == 0: matching_indices = [index for index, (first, _) in enumerate(edge_index) if first == s] for matching_index in matching_indices: have_recognized_list.append(matching_index) current_adj = edge_index[matching_index][1] if current_adj not in have_van_four_list and current_adj not in have_van_three_list and labels[current_adj] == 2: count += 1 if count == 2: two_list.append(s) for matching_index in matching_indices: current_adj = edge_index[matching_index][1] if current_adj not in have_van_four_list and current_adj not in have_van_three_list and labels[current_adj] == 2: have_van_two_list.append(current_adj) num_2 += 1 num_van_type_2 = len(set(have_van_two_list)) num_van_type_3 = len(set(have_van_three_list)) num_van_type_4 = len(set(have_van_four_list)) return num_2, num_3, num_4,num_van_type_2, num_van_type_3, num_van_type_4 # 遍历文件夹中的所有文件 for filename in os.listdir(folder_path): if filename.endswith('.json'): print(filename) json_file = os.path.join(folder_path, filename) num_2, num_3, num_4,num_van_type_2, num_van_type_3, num_van_type_4 = process_json_file(json_file) print(f"邻居中分别有2个3个4个线缺陷原子的正常原子数量: {num_2}, {num_3}, {num_4}") print(f"对邻居中分别有2个3个4个线缺陷原子的正常原子,进行该类型总的邻居缺陷原子的个数的统计: {num_van_type_2}, {num_van_type_3}, {num_van_type_4}")