# Source: atom-predict/egnn_v2/stastic_V2_aug_plt.py (192 lines, 7.3 KiB, Python)
#
# Note carried over from the code-hosting page this file was copied from:
# the file contains ambiguous Unicode characters that might be confused with
# other characters. If this is intentional, the warning can be ignored.

import math
import numpy as np
import matplotlib.pyplot as plt
from egnn_core.data import load_data_v2
import os
import json
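# Outline: walk the point table by following the graph edges.
# Step 1: iterate over the point indices in the first column of edge_index.T ([:, 0]).
# Step 2: look up each index's label and keep only normal atoms (label 0); skip
#         every other atom. This step excludes the non-normal atoms.
# Step 3: check whether a normal atom's neighbours are also normal:
#         first collect the point indices in the second column of edge_index.T
#         ([:, 1]), i.e. the atom's neighbours;
#         then count how many of those neighbours belong to a line defect (label 2).
# Remember to exclude each processed point from the table permanently, so that
# it is never counted twice.
# NOTE: process_json_file below no longer follows this outline exactly. It
# visits atoms with label 1 or 2, and counts the distinct (by coordinates)
# label-1/2 atoms reached through two edges (neighbours of neighbours) that lie
# within a distance of 52.5 of the visited atom.
# Example input:
# edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]])
# labels = np.array([0,0,0,2,0,2,0,2,0,0,2])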
# Folder that holds the *.json annotation files to process. The default is the
# original hard-coded location; set LINE_JSON_FOLDER to point elsewhere.
folder_path = os.environ.get('LINE_JSON_FOLDER', '/home/gao/mouclear/cc/final_todo/line')
def count_num(edge_index, num_index_list):
    """Count the edges whose source node is one of ``num_index_list``.

    Each edge is counted at most once, even when its source index appears
    several times in ``num_index_list``.

    Args:
        edge_index: iterable of ``(source, target)`` pairs, e.g. an ``(E, 2)``
            array or a list of 2-tuples.
        num_index_list: node indices to match against the edge sources.

    Returns:
        The number of edges ``(first, _)`` with ``first`` in ``num_index_list``.
    """
    # A set gives O(1) membership tests, so the edges are scanned once instead
    # of once per requested index; counting edges (not indices) also makes the
    # result naturally duplicate-free.
    wanted = set(num_index_list)
    return sum(1 for first, _ in edge_index if first in wanted)
# Scatter color for each nearby-defect count (index = count), used by
# _plot_count_map. Counts without an entry here are not drawn.
# NOTE(review): counts 6 and 7 share '#4e4aaa' (kept from the original
# palette) and counts >= 8 are never drawn -- confirm this is intended.
_COUNT_COLORS = ('#f0514f', '#f39156', '#fee358', '#b5fa9f',
                 '#50ecf8', '#5277ef', '#4e4aaa', '#4e4aaa')

# Number of per-count tallies returned by process_json_file (counts 0..10).
_NUM_TALLIES = 11


def _count_nearby_defects(points, edges, labels, max_dist=52.5):
    """Count, for each defect atom, the distinct defect atoms two hops away.

    A defect atom is one whose label is 1 or 2. Each edge source is visited
    once, in order of its first appearance as a source. For a visited defect
    atom ``s``, every candidate reached via ``s -> neighbour -> candidate``
    (following edges only in their stored direction) is considered. A
    candidate counts when all of the following hold:

    * it is a defect atom;
    * it is not ``s`` itself;
    * its coordinates have not already been considered for ``s``;
    * it lies within ``max_dist`` of ``s``, measured as the Euclidean distance
      on the first two coordinates.

    Args:
        points: per-atom coordinates; row ``k`` belongs to atom ``k``.
        edges: iterable of ``(source, target)`` atom-index pairs.
        labels: per-atom class labels, indexable by atom index.
        max_dist: distance cutoff, in the same units as ``points``.

    Returns:
        List of ``(atom_index, count)`` tuples in visiting order.
    """
    def is_defect(k):
        return labels[k] == 1 or labels[k] == 2

    # Directed adjacency (source -> targets, in edge order), built once so the
    # neighbour lookups below do not rescan the whole edge list every time.
    adjacency = {}
    for src, dst in edges:
        adjacency.setdefault(src, []).append(dst)

    results = []
    # Dicts keep insertion order, so each source is visited exactly once, in
    # order of its first appearance in ``edges``.
    for atom in adjacency:
        if not is_defect(atom):
            continue
        origin = points[atom]
        seen = set()  # candidate coordinates already considered for this atom
        count = 0
        for neighbour in adjacency[atom]:
            for cand in adjacency.get(neighbour, ()):
                if cand == atom or not is_defect(cand):
                    continue
                cand_point = points[cand]
                key = tuple(cand_point)
                if key in seen:
                    continue
                # Marked as seen whether or not it is within range, so the same
                # coordinates are never re-tested for this atom.
                seen.add(key)
                dist = math.sqrt((cand_point[0] - origin[0]) ** 2
                                 + (cand_point[1] - origin[1]) ** 2)
                if dist <= max_dist:
                    count += 1
        results.append((atom, count))
    return results


def _plot_count_map(coords_by_count, out_path):
    """Scatter atoms colored by their nearby-defect count and save the figure.

    Args:
        coords_by_count: sequence whose entry ``k`` is a list of point rows for
            atoms with count ``k``; each row is drawn at ``x=row[1], y=row[0]``.
            Only entries that have a color in ``_COUNT_COLORS`` are drawn.
        out_path: file path handed to ``plt.savefig``.
    """
    fig = plt.figure(figsize=(9, 9))
    try:
        # Groups are drawn in increasing count order, so higher counts end up
        # on top where markers overlap.
        for coords, color in zip(coords_by_count, _COUNT_COLORS):
            if coords:
                xy = np.asarray(coords)
                plt.scatter(xy[:, 1], xy[:, 0], s=23, c=color)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig(out_path, dpi=300)
    finally:
        # Release the figure; otherwise every processed file leaks one figure.
        plt.close(fig)


def process_json_file(json_file):
    """Tally and plot, per defect atom, the nearby defect atoms two hops away.

    Steps:

    1. Load ``json_file`` with ``load_data_v2``.
    2. Print the number of labels, then how many atoms have label 0, 1 and 2.
    3. For every atom with label 1 or 2, count the distinct label-1/2 atoms
       reached via two directed edges that lie within 52.5 of it (see
       ``_count_nearby_defects``).
    4. Save a scatter plot of those atoms, colored by count, as a PNG next to
       the JSON file.

    Args:
        json_file: path to a ``.json`` annotation file.

    Returns:
        Tuple ``(num_0, ..., num_10)``, where ``num_k`` is the number of defect
        atoms whose count is exactly ``k``. Counts above 10 are not tallied.
    """
    points, edge_index, labels, _ = load_data_v2(json_file)
    print(len(labels))
    print(np.sum(labels == 0))
    print(np.sum(labels == 1))
    print(np.sum(labels == 2))

    tallies = [0] * _NUM_TALLIES
    coords_by_count = [[] for _ in _COUNT_COLORS]
    # Transposed so that each row is one (source, target) edge -- presumably
    # load_data_v2 returns edge_index with shape (2, E); TODO confirm.
    for atom, count in _count_nearby_defects(points, edge_index.T, labels):
        if count < _NUM_TALLIES:
            tallies[count] += 1
        if count < len(coords_by_count):
            coords_by_count[count].append(points[atom])

    # splitext only swaps the extension, whereas str.replace('.json', ...)
    # would also rewrite any '.json' occurring in a directory name.
    _plot_count_map(coords_by_count, os.path.splitext(json_file)[0] + '.png')
    return tuple(tallies)
# Process every JSON file in the folder; sorted so the output order is
# reproducible (os.listdir order is arbitrary).
for filename in sorted(os.listdir(folder_path)):
    if not filename.endswith('.json'):
        continue
    print(filename)
    json_file = os.path.join(folder_path, filename)
    (num_0, num_1, num_2, num_3, num_4, num_5,
     num_6, num_7, num_8, num_9, num_10) = process_json_file(json_file)
    # Separators restored: the original printed {num_2}{num_3} and
    # {num_5}{num_6} run together, making the counts unreadable.
    print(f"邻居的邻居中分别有0个1个2个3个4个5个6个7个8个9个10个缺陷原子的缺陷原子数量: "
          f"{num_0}, {num_1}, {num_2}, {num_3}, {num_4}, {num_5}, "
          f"{num_6}, {num_7}, {num_8}, {num_9}, {num_10}")