atom-predict/egnn_v2/stastic_center_v3.py

163 lines
5.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import matplotlib.pyplot as plt
from egnn_core.data import load_data_v3,load_data_v4
import os
import json
#根据连接来遍历点的表
#第一步,遍历edge index.T里面的第一列对应的点的序号[:,0]
#第二步,根据序号去找对应的label是否为普通,即0,如果等于0进入下一步,不然跳过;这一步是为了排除掉非普通原子的
#第三步,判断普通的原子的邻居是不是也是普通的:
#先找edge index.T里面的第一列对应的点的序号[:,1],即邻居有哪些
#去判断这些邻居有多少个属于线缺陷,即label是2,并计数
#记住,此时要把该点从表中永远排除,避免重复技术
# edge_index = np.array([[1, 3], [4, 5], [1, 7], [1, 10]])
# labels = np.array([0,0,0,2,0,2,0,2,0,0,2])
# 指定文件夹路径
folder_path = '/home/gao/mouclear/analyze_center/2'
img_size = 2048
def draw_light(img, point, light, ps=1):
img_h, img_w = img.shape
h, w = point
# 计算窗口的边界
hs = np.clip(h - ps, 0, img_h - 1)
ws = np.clip(w - ps, 0, img_w - 1)
he = np.clip(h + ps, 0, img_h - 1)
we = np.clip(w + ps, 0, img_w - 1)
# 从light中提取窗口区域的像素值
window = light.reshape((he - hs, we - ws))
# 绘制窗口到背景图像上
img[hs:he, ws:we] = window
return img
def find_adj(edge_index,adj_indexs):
matching_indices = []
for matching_index in adj_indexs:
matching_indices.append([index for index, (first, _) in enumerate(edge_index) if (first == matching_index)])
matching_indices = [item for sublist in matching_indices if sublist for item in sublist]
adj_indexs = []
for matching_index in matching_indices:
adj_indexs.append(edge_index[matching_index][1])
return adj_indexs, matching_indices
def process_json_file(json_file,img_file):
points, edge_index, labels, lights = load_data_v3(json_file) # 假设这是你的数据加载函数
edge_index = edge_index.T
# 找到数组中等于1的元素的索引
center_index = np.where(labels == 3)
center_point = points[center_index]
# center_s = edge_index[center_index][:, 0]
adj_1_indexs = []
matching_indices_1 = [index for index, (first, _) in enumerate(edge_index) if (first == center_index )]
for matching_index in matching_indices_1:
adj_1_indexs.append(edge_index[matching_index][1])
adj_2_indexs,matching_indices_2 = find_adj(edge_index,adj_1_indexs)
adj_3_indexs,matching_indices_3 = find_adj(edge_index,adj_2_indexs)
adj_4_indexs,matching_indices_4 = find_adj(edge_index,adj_3_indexs)
combined_index = list(set(adj_1_indexs + adj_2_indexs + adj_3_indexs + adj_4_indexs))
matching_indices = list(set(matching_indices_1 + matching_indices_2 + matching_indices_3 + matching_indices_4))
print(len(combined_index))
print(len(matching_indices))
selected_point = []
selected_light = []
for i in combined_index:
selected_point.append(points[i])
selected_light.append(lights[i])
selected_point = np.array(selected_point)
selected_light = np.array(selected_light)
bg = np.zeros((img_size, img_size))
plt.figure(figsize=(9, 9))
for light, point in zip(selected_light, selected_point):
img_h, img_w = bg.shape # 背景图像的高度和宽度
h, w = point
ps = 1 # 因为 light 是 3x3 的区域,所以 ps 应该是 1
hs = np.clip(h - ps, 0, img_h - ps - 1)
ws = np.clip(w - ps, 0, img_w - ps - 1)
he = hs + 2 * ps + 1 # 3x3 区域的结束行索引
we = ws + 2 * ps + 1 # 3x3 区域的结束列索引
# 正确地将 light 重塑为 3x3 的二维数组
light_2d = light.reshape((3, 3))
# 绘制到背景图像上
bg[hs:he, ws:we] = light_2d
# 标记提取区域的中心点
bg[h, w] = 1 # 假设背景是0将中心点设置为1以标记
plt.imshow(bg, cmap='gray')
#
# min_x = np.min(selected_point[:, 1])
# max_x = np.max(selected_point[:, 1])
# min_y = np.min(selected_point[:, 0])
# max_y = np.max(selected_point[:, 0])
#
# # 计算边界框的宽度和高度
# width = max_x - min_x
# height = max_y - min_y
#
# # 为边界框添加一些边界例如原始宽度和高度的10%
# padding = 0.1
# border_x = int(width * padding)
# border_y = int(height * padding)
#
# # 调整背景图大小
# new_width = int(width + 2 * border_x)
# new_height = int(height + 2 * border_y)
# bg = np.zeros((new_height, new_width))
#
# # 调整点的坐标
# selected_point[:, 1] -= min_x - border_x
# selected_point[:, 0] -= min_y - border_y
#
# plt.figure(figsize=(new_width / 100, new_height / 100)) # 单位为英寸100是默认的dpi
# plt.imshow(bg, cmap='gray')
#
# for point in selected_point:
# # print(points[i])
# # if labels[i] == 0:
# plt.scatter(point[1], point[0], s=18, c='white', zorder=2)
# # elif labels[i] == 1 or 2:
# # plt.scatter(points[i][1], points[i][0], s=18, c='yellow', zorder=2)
plt.axis('off')
plt.tight_layout()
# plt.tight_layout()
plt.savefig('2_.png')
plt.show()
# 遍历文件夹中的所有文件
for filename in os.listdir(folder_path):
if filename.endswith('.json'):
print(filename)
json_file = os.path.join(folder_path, filename)
if filename.endswith('.jpg'):
print(filename)
img_file = os.path.join(folder_path, filename)
process_json_file(json_file,img_file)