atom-predict/egnn_v2/metricse2e_vor.py

149 lines
4.6 KiB
Python

import os
import shutil
import glob
import json
import numpy as np
from tqdm import tqdm
from egnn_core.data import load_data
from model_type_dict import class_dict, cz_class_dict, xj_class_dict, smov_class_dict, norm_line_sv_label, \
cz_label, xj_label, smov_label
def get_class_dict(model_type):
    """Return the class mapping associated with ``model_type``.

    ``post_process`` indexes the returned mapping with predicted class indices
    to obtain label strings. Candidates are compared with ``==`` in a fixed
    order (cz, xj, smov, norm_line_sv) and the first match wins.

    Raises:
        ValueError: If ``model_type`` matches none of the known labels.
    """
    candidates = (
        (cz_label, cz_class_dict),
        (xj_label, xj_class_dict),
        (smov_label, smov_class_dict),
        (norm_line_sv_label, class_dict),
    )
    for label, mapping in candidates:
        if model_type == label:
            return mapping
    raise ValueError('Invalid model type')
def find_and_copy_image(json_path, save_path):
    """Copy the ``.jpg`` sharing ``json_path``'s base name into ``save_path``.

    Looks for ``<directory of json_path>/<stem>.jpg``, where ``stem`` is the
    JSON file name without its final extension. If that file exists, it is
    copied into ``save_path`` under the same file name (creating
    ``save_path`` if needed) and a message is printed. Otherwise nothing
    happens.

    Args:
        json_path: Path to an annotation JSON file.
        save_path: Destination directory for the image.
    """
    # Directory containing the JSON file.
    directory = os.path.dirname(json_path)
    # File name without its extension.
    base_name = os.path.splitext(os.path.basename(json_path))[0]
    # Check the exact path instead of globbing it. A name containing glob
    # metacharacters such as '[' or '*' would otherwise be interpreted as a
    # pattern, and the image would silently not be found.
    jpg_file = os.path.join(directory, f'{base_name}.jpg')
    if not os.path.isfile(jpg_file):
        return
    # The destination directory may not exist yet when called standalone.
    os.makedirs(save_path, exist_ok=True)
    destination = os.path.join(save_path, os.path.basename(jpg_file))
    shutil.copy(jpg_file, destination)
    print(f"Copied {jpg_file} to {destination}")
def _load_predictions(predict_result):
    """Map each prediction name to the arg-max class index of its score row."""
    with open(predict_result, encoding='utf-8') as f:
        prediction = json.load(f)
    names = np.array(prediction['name'])
    pred = np.argmax(np.array(prediction['pred']), axis=1)
    return dict(zip(names, pred))


def _point_key(base_name, point):
    """Build the prediction key ``"<base_name>_<c0>_<c1>..."`` for one point."""
    return '{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))


def post_process(predict_result, test_path, save_path, model_type='norm_sv_line'):
    """Write predicted labels into copies of every annotation JSON under ``test_path``.

    ``predict_result`` is a JSON file with parallel ``name`` and ``pred``
    lists. The arg-max over each ``pred`` row is taken as the predicted class
    index. For each ``*.json`` found recursively under ``test_path``:

    1. Every point returned by ``load_data`` is looked up by its
       ``"<base_name>_<coords...>"`` key.
    2. The corresponding entry of ``shapes`` gets its ``label`` replaced with
       the class name from ``get_class_dict(model_type)``.
    3. The result is written to the same relative path under ``save_path``.

    A same-named ``.jpg`` next to the source JSON, if any, is copied next to
    the written JSON.

    Args:
        predict_result: Path to the prediction JSON.
        test_path: Root directory searched recursively for annotation JSONs.
        save_path: Output root; the directory layout of ``test_path`` is
            mirrored beneath it.
        model_type: Label-set identifier passed to ``get_class_dict``.
            NOTE(review): the default ``'norm_sv_line'`` must equal one of the
            imported labels, or ``get_class_dict`` raises. Confirm against
            ``model_type_dict``.

    Raises:
        ValueError: If ``model_type`` is not a known label set.
        KeyError: If a point has no entry in the prediction file.
        IndexError: If a JSON has fewer ``shapes`` than ``load_data`` points.
    """
    os.makedirs(save_path, exist_ok=True)
    # Resolve the label set once, before any file is touched. This was
    # previously re-resolved (and printed) for every file.
    class_name = get_class_dict(model_type)
    print(f"model_type:{model_type}, class_dict:{class_name}")

    pred_dict = _load_predictions(predict_result)

    json_lst = glob.glob(os.path.join(test_path, '**', '*.json'), recursive=True)
    for json_path in tqdm(json_lst):
        # Text before the first '.' of the file name. This must match the
        # prefix used when the prediction names were generated.
        base_name = os.path.basename(json_path).split('.')[0]
        points, _, _, _ = load_data(json_path)
        labels = np.array([pred_dict[_point_key(base_name, point)] for point in points])

        with open(json_path, encoding='utf-8') as f:
            annotation = json.load(f)
        shapes = annotation['shapes']
        # NOTE(review): assumes load_data yields points in the same order as
        # annotation['shapes']. Confirm against egnn_core.data.load_data.
        for i, index in enumerate(labels):
            shapes[i]['label'] = class_name[index]

        relative_path = os.path.relpath(json_path, test_path)
        target_path = os.path.join(save_path, relative_path)
        target_dir = os.path.dirname(target_path)
        os.makedirs(target_dir, exist_ok=True)
        with open(target_path, 'w', encoding='utf-8') as f:
            json.dump(annotation, f)
        # Keep the image beside its JSON. Copying to the save_path root would
        # separate them for nested inputs and let same-named images collide.
        find_and_copy_image(json_path, target_dir)
# predict_result = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/egnn_v2/logs/0/version_10/test.json'
# test_path = './gnn_sv_train/test/'
# save_path = './predictddd_result/post_process/'
#
# post_process(predict_result, test_path, save_path)
# with open('/home/deploy/script/bohrim-app/bohrium-app/pythonProject/egnn_v2/logs/0/version_10/test.json') as f:
# data = json.load(f)
#
# name = np.array(data['name'])
# pred = np.argmax(np.array(data['pred']), axis=1)
# pred_dict = dict(zip(name, pred))
#
# json_lst = glob.glob('./gnn_sv_train/test/*.json', recursive=True); len(json_lst)
#
# for json_path in tqdm(json_lst):
# base_name = json_path.split('/')[-1].split('.')[0]
# points, edge_index, gt_label, _ = load_data(json_path)
# labels = np.array([pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))] for point in points])
#
# with open(json_path) as f:
# data = json.load(f)
#
# for i in range(len(labels)):
# data['shapes'][i]['label'] = class_dict[labels[i] + 1]
#
# with open(json_path, 'w') as f:
# json.dump(data, f)
# #
# json_lst = glob.glob('/home/gao/mouclear/cc/data_new/gnn_sv/after_test/raw/*.json', recursive=True);
# len(json_lst)
# res = []
#
# for json_path in tqdm(json_lst):
# base_name = json_path.split('/')[-1].split('.')[0]
# points, edge_index, labels, _ = load_data(json_path)
#
# mask_pd = np.zeros((2048, 2048))
# mask_pd[points[:, 0], points[:, 1]] = labels + 1
# mask_pd = np.array(mask_pd, np.uint8)
#
# mask_gt = np.array(Image.open(json_path.replace('.json', '.png')), np.uint8)
#
# for i in range(1, 4):
# res += [get_metrics(mask_gt == i, mask_pd == i)]
#
# res = np.array(res)
#
#
# # Norm
# print(np.mean(res[::3, :], axis=0))
#
# # SV
# print(np.mean(res[1::3, :], axis=0))
#
# # LineSV
# print(np.mean(res[2::3, :], axis=0))
#
# print(np.mean(res, axis=0))