# atom-predict/egnn_v2/egnn_core/data.py
# (file listing metadata: 613 lines, 19 KiB, Python)
import cv2
import copy
import glob
import json
import torch
import numpy as np
from PIL import Image
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset, download_url, Dataset
from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay
from model_type_dict import norm_line_sv_label
import albumentations as A
# Sanity-check the installed albumentations version at import time.
# A failure is non-fatal: it is only reported, and the module keeps loading.
try:
    A.check_version()
except Exception as e:
    print(f"Version check failed: {e}")
def get_training_augmentation():
    """Build the albumentations pipeline applied to training images.

    Applies Gaussian blur, random brightness/contrast jitter, and
    normalization, in that order.
    """
    return A.Compose([
        A.GaussianBlur(),
        A.RandomBrightnessContrast(),
        A.Normalize(),
    ])
def get_validation_augmentation(model_type=None):
    """Build the albumentations pipeline for validation/test images.

    Only the line-defect model (``model_type == norm_line_sv_label``) gets
    a Normalize step; every other model type gets an empty (identity)
    pipeline.
    """
    if model_type == norm_line_sv_label:
        print("线缺陷检测模型需要归一化")
        steps = [A.Normalize()]
    else:
        print("其他模型不需要归一化")
        steps = []
    return A.Compose(steps)
def get_y_3(y, edge_index):
    """Promote defect labels that have another defect within two hops.

    Builds a symmetric adjacency matrix from ``edge_index`` and, for every
    node with a non-zero label, inspects its 1-hop and 2-hop neighborhood:
    if any neighbor is also non-zero, the node's output label becomes 2,
    otherwise it stays 1. Zero-labelled nodes stay 0.

    Parameters
    ----------
    y : np.ndarray
        Per-node integer labels; 0 means "normal".
    edge_index : np.ndarray, shape (2, E)
        Edge list; iterated column-wise as (source, target) pairs.

    Returns
    -------
    np.ndarray of np.uint8
        0 = normal, 1 = isolated defect, 2 = defect with a defect neighbor
        within two hops.
    """
    n = len(y)
    conn_matrix = np.zeros((n, n))
    for s, e in edge_index.transpose():
        conn_matrix[s, e] = 1
        conn_matrix[e, s] = 1
    y_sv_norm = np.array(y != 0, np.uint8)
    y_pred = np.array(y != 0, np.uint8)
    for i in range(n):
        if y_pred[i] == 0:
            continue
        n1 = np.where(conn_matrix[i] != 0)[0]
        if len(n1) == 0:
            # Isolated node: no neighborhood to inspect, label stays 1.
            # (The original reached this case via a bare except.)
            continue
        n2 = np.concatenate([np.where(conn_matrix[item] != 0)[0] for item in n1])
        ner = set(np.concatenate([n1, n2]))
        ner.discard(i)  # exclude the node itself; no-op when absent
        if ner and np.max(y_sv_norm[list(ner)]) > 0:
            y_pred[i] = 2
    return y_pred
def get_light(img, point, ps=1):
    """Extract a flattened (2*ps+1) x (2*ps+1) intensity patch around ``point``.

    The window start is clipped so the patch always lies fully inside the
    image; near a border the window shifts inward instead of shrinking.

    Parameters
    ----------
    img : np.ndarray, shape (H, W)
        Grayscale image.
    point : sequence of two ints
        (row, col) patch center.
    ps : int
        Patch half-size.
    """
    img_h, img_w = img.shape
    h, w = point
    side = 2 * ps + 1
    top = np.clip(h - ps, 0, img_h - side)
    left = np.clip(w - ps, 0, img_w - side)
    return img[top:top + side, left:left + side].flatten()
def get_edge_index_delaunay(points, max_edge_length=35):
    """Build a bidirectional edge index from a Delaunay triangulation.

    Triangulates ``points`` and keeps every triangle edge whose Euclidean
    length is at most ``max_edge_length``. Each undirected edge appears in
    both directions in the result.

    The original implementation deduplicated twice (a set of sorted tuples,
    then again through "s_e" strings); the set alone already guarantees
    uniqueness, so the string round-trip was removed.

    Parameters
    ----------
    points : np.ndarray, shape (N, 2)
        Point coordinates.
    max_edge_length : float
        Maximum allowed edge length.

    Returns
    -------
    np.ndarray, shape (2, 2*E)
        Directed edge list covering both directions of each kept edge.
    """
    tri = Delaunay(points)
    edges = set()
    for simplex in tri.simplices:
        for i in range(3):
            for j in range(i + 1, 3):
                a, b = sorted((simplex[i], simplex[j]))
                if np.linalg.norm(points[a] - points[b]) <= max_edge_length:
                    edges.add((int(a), int(b)))
    s = [a for a, b in edges]
    e = [b for a, b in edges]
    return np.array([s + e, e + s])
def get_edge_index(points, lights, max_ner=3, dis_thres=35, light_thres=120):
    """Connect each sufficiently bright point to its nearest neighbors.

    For each point, up to ``max_ner`` nearest other points are considered;
    candidates that were already processed on an earlier iteration are
    skipped (they had their chance to link), and the rest are kept when
    closer than ``dis_thres``. Edges are emitted only for points whose mean
    patch intensity exceeds ``light_thres``. The result is deduplicated and
    returned in both directions.

    Bug fixed: the original popped from ``ner_idxs`` while iterating it with
    ``enumerate``, which skips the element following each removal; already-
    connected candidates are now removed with a filter instead.

    Parameters
    ----------
    points : np.ndarray, shape (N, 2)
    lights : np.ndarray, shape (N, K)
        Per-point intensity patches (see ``get_light``).
    max_ner : int
        Maximum neighbor candidates per point.
    dis_thres : float
        Maximum edge length.
    light_thres : float
        Minimum mean patch intensity for a point to emit edges.

    Returns
    -------
    np.ndarray, shape (2, 2*E)
    """
    s, t = [], []
    connected = set()
    for i, p in enumerate(points):
        diffs = np.abs(p.reshape(1, -1) - points)
        diss = np.sqrt(diffs[:, 0] ** 2 + diffs[:, 1] ** 2)
        cand = np.argsort(diss)[1:1 + max_ner]
        cand = np.array([idx for idx in cand if idx not in connected], dtype=int)
        keep = cand[diss[cand] < dis_thres].tolist()
        if np.mean(lights[i]) > light_thres:
            s += [i] * len(keep)
            t += keep
        connected.add(i)
    # Canonicalize each pair (smaller id first), deduplicate, then emit both
    # directions.
    edge_index = np.sort(np.array([s, t]), axis=0)
    pairs = {(int(a), int(b)) for a, b in edge_index.transpose()}
    src = [a for a, b in pairs]
    dst = [b for a, b in pairs]
    return np.array([src + dst, dst + src])
def load_data(json_path):
    """Load one labelme-style annotation and build a point graph from it.

    Reads ``<name>.json`` plus the sibling ``<name>.jpg``, converts each
    shape's first point from (x, y) to (row, col), maps labels through a
    fixed Norm/SV/LineSV dictionary, extracts an intensity patch per point,
    and connects points with the distance/brightness heuristic of
    ``get_edge_index``.

    Returns
    -------
    tuple of np.ndarray
        (points, edge_index, labels, lights).
    """
    data_dict = {
        'Norm': 0,
        'SV': 1,
        'LineSV': 2,
    }
    img_path = json_path.replace('.json', '.jpg')
    img = np.array(Image.open(img_path))
    with open(json_path) as f:
        json_data = json.load(f)
    shapes = json_data['shapes']
    # labelme stores (x, y); [::-1] flips to (row, col) image indexing.
    points = np.array([item['points'][0][::-1] for item in shapes], np.int32)
    try:
        labels = np.array([data_dict[item['label']] for item in shapes], np.uint8)
    except KeyError:
        # Any unknown label falls back to an all-"Norm" labelling.
        labels = np.array([0 for _ in shapes], np.uint8)
    lights = np.array([get_light(img, point) for point in points])
    edge_index = get_edge_index(points, lights)
    return points, edge_index, labels, lights
def load_data_v2(json_path):
    """Load a labelme annotation; graph edges come from a Delaunay triangulation.

    Like ``load_data`` but with a wider label dictionary (several aliasing
    label schemes map onto classes 0-2) and edges built by
    ``get_edge_index_delaunay`` with ``max_edge_length=35.3``.

    Returns
    -------
    tuple of np.ndarray
        (points, edge_index, labels, lights).
    """
    # NOTE: the original defined a plain Norm/SV/LineSV dict first and
    # immediately shadowed it with this merged mapping; only this one was
    # ever used, so the dead dict was removed.
    data_dict = {
        'Norm': 0,
        'SV': 1,
        'LineSV': 2,
        'atom': 0,
        'cz': 1,
        '0': 0,
        '1': 1,
        '2': 2,
        'S': 0,
        'Mo': 1,
        'V': 2,
    }
    img_path = json_path.replace('.json', '.jpg')
    img = np.array(Image.open(img_path))
    with open(json_path) as f:
        json_data = json.load(f)
    shapes = json_data['shapes']
    # labelme stores (x, y); [::-1] flips to (row, col) image indexing.
    points = np.array([item['points'][0][::-1] for item in shapes], np.int32)
    try:
        labels = np.array([data_dict[item['label']] for item in shapes], np.uint8)
    except KeyError:
        # Any unknown label falls back to an all-zero labelling.
        labels = np.array([0 for _ in shapes], np.uint8)
    lights = np.array([get_light(img, point) for point in points])
    edge_index = get_edge_index_delaunay(points, max_edge_length=35.3)
    return points, edge_index, labels, lights
def load_data_v2_aug(json_path, aug, edges_length=35.5):
    """Load a labelme annotation, augmenting the image before patch extraction.

    Like ``load_data_v2`` but the image is passed through an albumentations
    pipeline ``aug`` first, and the Delaunay edge-length cutoff is a
    parameter.

    Parameters
    ----------
    json_path : str
        Path to the ``.json`` annotation; the image is the sibling ``.jpg``.
    aug : albumentations.Compose
        Augmentation pipeline applied to the raw image.
    edges_length : float
        Max edge length for ``get_edge_index_delaunay``.

    Returns
    -------
    tuple of np.ndarray
        (points, edge_index, labels, lights).
    """
    data_dict = {
        'Norm': 0,
        'SV': 1,
        'LineSV': 2,
    }
    img_path = json_path.replace('.json', '.jpg')
    print("img_path:", img_path)
    image = Image.open(img_path)
    # Apply the augmentation pipeline; patches are taken from the augmented
    # (possibly normalized) image, not the raw one.
    img = aug(image=np.array(image))['image']
    with open(json_path) as f:
        json_data = json.load(f)
    shapes = json_data['shapes']
    # labelme stores (x, y); [::-1] flips to (row, col) image indexing.
    points = np.array([item['points'][0][::-1] for item in shapes], np.int32)
    try:
        labels = np.array([data_dict[item['label']] for item in shapes], np.uint8)
    except KeyError:
        # Any unknown label falls back to an all-"Norm" labelling.
        labels = np.array([0 for _ in shapes], np.uint8)
    lights = np.array([get_light(img, point) for point in points])
    edge_index = get_edge_index_delaunay(points, edges_length)
    return points, edge_index, labels, lights
def load_data_v3(json_path):
    """Load a labelme annotation with an extra 'center' class (Delaunay edges).

    Same pipeline as ``load_data_v2`` but the label dictionary adds
    ``center -> 3`` and the Delaunay edge cutoff is 36 px.

    Returns
    -------
    tuple of np.ndarray
        (points, edge_index, labels, lights).
    """
    data_dict = {
        'Norm': 0,
        'SV': 1,
        'LineSV': 2,
        'center': 3,
    }
    img_path = json_path.replace('.json', '.jpg')
    img = np.array(Image.open(img_path))
    with open(json_path) as f:
        json_data = json.load(f)
    shapes = json_data['shapes']
    # labelme stores (x, y); [::-1] flips to (row, col) image indexing.
    points = np.array([item['points'][0][::-1] for item in shapes], np.int32)
    try:
        labels = np.array([data_dict[item['label']] for item in shapes], np.uint8)
    except KeyError:
        # Any unknown label falls back to an all-"Norm" labelling.
        labels = np.array([0 for _ in shapes], np.uint8)
    lights = np.array([get_light(img, point) for point in points])
    edge_index = get_edge_index_delaunay(points, max_edge_length=36)
    return points, edge_index, labels, lights
def load_data_v4(json_path):
    """Load a labelme annotation with a 'center' class (nearest-neighbor edges).

    Same label dictionary as ``load_data_v3`` but edges come from the
    distance/brightness heuristic ``get_edge_index`` instead of a Delaunay
    triangulation.

    Returns
    -------
    tuple of np.ndarray
        (points, edge_index, labels, lights).
    """
    data_dict = {
        'Norm': 0,
        'SV': 1,
        'LineSV': 2,
        'center': 3,
    }
    img_path = json_path.replace('.json', '.jpg')
    img = np.array(Image.open(img_path))
    with open(json_path) as f:
        json_data = json.load(f)
    shapes = json_data['shapes']
    # labelme stores (x, y); [::-1] flips to (row, col) image indexing.
    points = np.array([item['points'][0][::-1] for item in shapes], np.int32)
    try:
        labels = np.array([data_dict[item['label']] for item in shapes], np.uint8)
    except KeyError:
        # Any unknown label falls back to an all-"Norm" labelling.
        labels = np.array([0 for _ in shapes], np.uint8)
    lights = np.array([get_light(img, point) for point in points])
    edge_index = get_edge_index(points, lights)
    return points, edge_index, labels, lights
def load_data_v5(json_path):
    """Load a labelme annotation with 'graph'/'graphcenter' classes.

    Same pipeline as ``load_data_v2`` but the label dictionary adds
    ``graph -> 3`` and ``graphcenter -> 4``; Delaunay edge cutoff is 35 px.

    Returns
    -------
    tuple of np.ndarray
        (points, edge_index, labels, lights).
    """
    data_dict = {
        'Norm': 0,
        'SV': 1,
        'LineSV': 2,
        'graph': 3,
        'graphcenter': 4,
    }
    img_path = json_path.replace('.json', '.jpg')
    img = np.array(Image.open(img_path))
    with open(json_path) as f:
        json_data = json.load(f)
    shapes = json_data['shapes']
    # labelme stores (x, y); [::-1] flips to (row, col) image indexing.
    points = np.array([item['points'][0][::-1] for item in shapes], np.int32)
    try:
        labels = np.array([data_dict[item['label']] for item in shapes], np.uint8)
    except KeyError:
        # Any unknown label falls back to an all-"Norm" labelling.
        labels = np.array([0 for _ in shapes], np.uint8)
    lights = np.array([get_light(img, point) for point in points])
    edge_index = get_edge_index_delaunay(points, max_edge_length=35)
    return points, edge_index, labels, lights
def load_data_v6(json_path):
    """Load a labelme annotation with 'graph'/'graphcenter'/'vacancy' classes.

    Same pipeline as ``load_data_v5`` with an additional ``vacancy -> 5``
    class; Delaunay edge cutoff is 35 px.

    Returns
    -------
    tuple of np.ndarray
        (points, edge_index, labels, lights).
    """
    data_dict = {
        'Norm': 0,
        'SV': 1,
        'LineSV': 2,
        'graph': 3,
        'graphcenter': 4,
        'vacancy': 5,
    }
    img_path = json_path.replace('.json', '.jpg')
    img = np.array(Image.open(img_path))
    with open(json_path) as f:
        json_data = json.load(f)
    shapes = json_data['shapes']
    # labelme stores (x, y); [::-1] flips to (row, col) image indexing.
    points = np.array([item['points'][0][::-1] for item in shapes], np.int32)
    try:
        labels = np.array([data_dict[item['label']] for item in shapes], np.uint8)
    except KeyError:
        # Any unknown label falls back to an all-"Norm" labelling.
        labels = np.array([0 for _ in shapes], np.uint8)
    lights = np.array([get_light(img, point) for point in points])
    edge_index = get_edge_index_delaunay(points, max_edge_length=35)
    return points, edge_index, labels, lights
def find_ner(idx, edge_index):
    """Return the unique neighbor ids of node ``idx`` in a (2, E) edge index.

    Looks at both directions, so it works for directed and undirected edge
    lists alike.
    """
    outgoing = edge_index[1][edge_index[0] == idx].tolist()
    incoming = edge_index[0][edge_index[1] == idx].tolist()
    return list(set(outgoing + incoming))
def get_adj_matrix(n_nodes, batch_size=1, device='cpu'):
    """Build a fully-connected (self-loops included) edge list per batch sample.

    The original version referenced undefined ``batch_size``/``device``
    names, recursed with a mismatched 3-argument signature, and left the
    column tensor on the CPU — it could never execute. The parameters are
    now explicit with backward-compatible defaults, and both tensors are
    moved to ``device``.

    Parameters
    ----------
    n_nodes : int
        Number of nodes per sample; sample ``k`` occupies ids
        ``[k*n_nodes, (k+1)*n_nodes)``.
    batch_size : int
        Number of samples in the batch.
    device : str or torch.device
        Target device for the returned tensors.

    Returns
    -------
    list of two torch.LongTensor
        Row and column indices of every (i, j) pair within each sample.
    """
    rows, cols = [], []
    for batch_idx in range(batch_size):
        offset = batch_idx * n_nodes
        for i in range(n_nodes):
            for j in range(n_nodes):
                rows.append(offset + i)
                cols.append(offset + j)
    return [torch.LongTensor(rows).to(device), torch.LongTensor(cols).to(device)]
class AtomDataset(InMemoryDataset):
    """Per-atom neighborhood graphs built from labelme JSON via ``load_data``.

    ``process`` turns every annotated atom into one fully-connected graph
    sample containing the atom plus its 1- and 2-hop neighbors; only the
    center atom carries a label (the rest are masked out).
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super(AtomDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Raw files are discovered by glob in process(); nothing to declare.
        return []

    @property
    def processed_file_names(self):
        return ['atom.dataset']

    def download(self):
        # Data is expected to already exist under <root>/raw.
        pass

    def process(self):
        json_lst = glob.glob('{}/raw/*.json'.format(self.root), recursive=True)
        print('Json data number: {}.'.format(len(json_lst)))
        data_lst = []
        for json_path in json_lst:
            base_name = json_path.split('/')[-1].split('.')[0]
            points, edge_index, labels, lights = load_data(json_path)
            for idx, center in enumerate(points):
                # Collect the 1-hop and 2-hop neighborhood of the center atom.
                hop1 = find_ner(idx, edge_index)
                hop2 = []
                if len(hop1) != 0:
                    nested = [find_ner(nid, edge_index) for nid in hop1]
                    hop2 = list(set([m for sub in nested for m in sub]))
                    if idx in hop2:
                        hop2.remove(idx)
                ner = [idx] + list(set(hop1 + hop2))
                atom_nums = len(ner)
                # Patch intensities in [0, 1]; positions centered and scaled.
                sample_light = torch.FloatTensor(lights[ner]) / 255.
                sample_pos = torch.FloatTensor((points[ner] - np.mean(points[ner], axis=0)) / 12.)
                # Only the first node (the center atom) is labelled/masked in.
                sample_label = torch.tensor([labels[idx]] + [0] * (atom_nums - 1), dtype=torch.int64)
                sample_mask = torch.tensor([1] + [0] * (atom_nums - 1), dtype=torch.bool)
                # Fully-connected edge index over the neighborhood.
                sample_edge_index = torch.LongTensor(
                    np.array([(i, j) for i in range(atom_nums) for j in range(atom_nums)]).transpose())
                data_lst += [Data(
                    pos=sample_pos,
                    light=sample_light,
                    label=sample_label,
                    mask=sample_mask,
                    edge_index=sample_edge_index,
                    name='{}_{}_{}'.format(base_name, center[0], center[1])
                )]
        data, slices = self.collate(data_lst)
        torch.save((data, slices), self.processed_paths[0])
class AtomDataset_v2(InMemoryDataset):
    """Per-atom neighborhood graphs built via ``load_data_v2`` (Delaunay edges).

    Identical sample construction to ``AtomDataset``; only the loader (and
    therefore the label dictionary and edge heuristic) differs.
    """

    def __init__(self, root, transform=None, pre_transform=None):
        super(AtomDataset_v2, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Raw files are discovered by glob in process(); nothing to declare.
        return []

    @property
    def processed_file_names(self):
        return ['atom.dataset']

    def download(self):
        # Data is expected to already exist under <root>/raw.
        pass

    def process(self):
        json_lst = glob.glob('{}/raw/*.json'.format(self.root), recursive=True)
        print('Json data number: {}.'.format(len(json_lst)))
        data_lst = []
        for json_path in json_lst:
            base_name = json_path.split('/')[-1].split('.')[0]
            points, edge_index, labels, lights = load_data_v2(json_path)
            for idx, center in enumerate(points):
                # Collect the 1-hop and 2-hop neighborhood of the center atom.
                hop1 = find_ner(idx, edge_index)
                hop2 = []
                if len(hop1) != 0:
                    nested = [find_ner(nid, edge_index) for nid in hop1]
                    hop2 = list(set([m for sub in nested for m in sub]))
                    if idx in hop2:
                        hop2.remove(idx)
                ner = [idx] + list(set(hop1 + hop2))
                atom_nums = len(ner)
                # Patch intensities in [0, 1]; positions centered and scaled.
                sample_light = torch.FloatTensor(lights[ner]) / 255.
                sample_pos = torch.FloatTensor((points[ner] - np.mean(points[ner], axis=0)) / 12.)
                # Only the first node (the center atom) is labelled/masked in.
                sample_label = torch.tensor([labels[idx]] + [0] * (atom_nums - 1), dtype=torch.int64)
                sample_mask = torch.tensor([1] + [0] * (atom_nums - 1), dtype=torch.bool)
                # Fully-connected edge index over the neighborhood.
                sample_edge_index = torch.LongTensor(
                    np.array([(i, j) for i in range(atom_nums) for j in range(atom_nums)]).transpose())
                data_lst += [Data(
                    pos=sample_pos,
                    light=sample_light,
                    label=sample_label,
                    mask=sample_mask,
                    edge_index=sample_edge_index,
                    name='{}_{}_{}'.format(base_name, center[0], center[1])
                )]
        data, slices = self.collate(data_lst)
        torch.save((data, slices), self.processed_paths[0])
class AtomDataset_v2_aug(InMemoryDataset):
    """Per-atom neighborhood graphs with albumentations augmentation.

    Globs JSON directly under ``root`` (not ``root/raw``), and picks the
    training or validation augmentation pipeline based on the parent
    directory name ('train' gets training augmentation).
    """

    def __init__(self, root, edges_length=35.5, model_type=None, transform=None, pre_transform=None):
        # Stored before super().__init__ because process() may run inside it.
        self.edges_length = edges_length
        self.model_type = model_type
        super(AtomDataset_v2_aug, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        # Raw files are discovered by glob in process(); nothing to declare.
        return []

    @property
    def processed_file_names(self):
        return ['atom.dataset']

    def download(self):
        # Data is expected to already exist under root.
        pass

    def process(self):
        json_lst = glob.glob('{}/*.json'.format(self.root), recursive=True)
        print('Json data number: {}.'.format(len(json_lst)))
        print("edges:", self.edges_length)
        data_lst = []
        for json_path in json_lst:
            base_name = json_path.split('/')[-1].split('.')[0]
            # NOTE(review): the train/val split is inferred from the parent
            # directory name via '/'-splitting — POSIX paths assumed.
            if json_path.split('/')[-2] == 'train':
                aug = get_training_augmentation()
            else:
                aug = get_validation_augmentation(self.model_type)
            points, edge_index, labels, lights = load_data_v2_aug(json_path, aug, self.edges_length)
            for idx, center in enumerate(points):
                # Collect the 1-hop and 2-hop neighborhood of the center atom.
                hop1 = find_ner(idx, edge_index)
                hop2 = []
                if len(hop1) != 0:
                    nested = [find_ner(nid, edge_index) for nid in hop1]
                    hop2 = list(set([m for sub in nested for m in sub]))
                    if idx in hop2:
                        hop2.remove(idx)
                ner = [idx] + list(set(hop1 + hop2))
                atom_nums = len(ner)
                # Patch intensities in [0, 1]; positions centered and scaled.
                sample_light = torch.FloatTensor(lights[ner]) / 255.
                sample_pos = torch.FloatTensor((points[ner] - np.mean(points[ner], axis=0)) / 12.)
                # Only the first node (the center atom) is labelled/masked in.
                sample_label = torch.tensor([labels[idx]] + [0] * (atom_nums - 1), dtype=torch.int64)
                sample_mask = torch.tensor([1] + [0] * (atom_nums - 1), dtype=torch.bool)
                # Fully-connected edge index over the neighborhood.
                sample_edge_index = torch.LongTensor(
                    np.array([(i, j) for i in range(atom_nums) for j in range(atom_nums)]).transpose())
                data_lst += [Data(
                    pos=sample_pos,
                    light=sample_light,
                    label=sample_label,
                    mask=sample_mask,
                    edge_index=sample_edge_index,
                    name='{}_{}_{}'.format(base_name, center[0], center[1])
                )]
        data, slices = self.collate(data_lst)
        torch.save((data, slices), self.processed_paths[0])