149 lines
4.6 KiB
Python
149 lines
4.6 KiB
Python
import os
|
|
import shutil
|
|
|
|
import glob
|
|
import json
|
|
import numpy as np
|
|
|
|
from tqdm import tqdm
|
|
from egnn_core.data import load_data
|
|
|
|
from model_type_dict import class_dict, cz_class_dict, xj_class_dict, smov_class_dict, norm_line_sv_label, \
|
|
cz_label, xj_label, smov_label
|
|
|
|
def get_class_dict(model_type):
    """Return the class mapping registered for ``model_type``.

    ``post_process`` indexes the returned object with a predicted class index
    and uses the result as a shape label.

    Args:
        model_type: One of the label constants imported from
            ``model_type_dict`` (``cz_label``, ``xj_label``, ``smov_label``,
            ``norm_line_sv_label``).

    Returns:
        The class dictionary associated with ``model_type``.

    Raises:
        ValueError: If ``model_type`` matches none of the known labels.
    """
    # Ordered (label, mapping) pairs; checked first-to-last with ``==`` so the
    # precedence is exactly that of the original if/elif chain.
    registry = (
        (cz_label, cz_class_dict),
        (xj_label, xj_class_dict),
        (smov_label, smov_class_dict),
        (norm_line_sv_label, class_dict),
    )
    for known_label, mapping in registry:
        if model_type == known_label:
            return mapping
    raise ValueError('Invalid model type')
|
|
|
|
|
|
def find_and_copy_image(json_path, save_path):
    """Copy the ``.jpg`` image that sits next to ``json_path`` into ``save_path``.

    The image is looked up in the same directory as the JSON file and must
    share its base name (``foo.json`` -> ``foo.jpg``). If no such regular file
    exists, nothing is copied.

    Args:
        json_path: Path to the annotation JSON file.
        save_path: Destination directory. It is created if it does not exist.

    Returns:
        A list containing the destination path of the copied image, or an
        empty list when no image was found. Callers that ignore the return
        value (the original returned ``None``) are unaffected.
    """
    # Directory that contains the JSON file.
    directory = os.path.dirname(json_path)
    # File name without its extension.
    base_name = os.path.splitext(os.path.basename(json_path))[0]
    jpg_path = os.path.join(directory, f'{base_name}.jpg')

    # Test the literal path rather than globbing it. File names containing
    # glob metacharacters ("[", "]", "*", "?") would otherwise be interpreted
    # as patterns and the image silently skipped. glob would also match
    # directories.
    if not os.path.isfile(jpg_path):
        return []

    # An empty save_path means "current directory"; os.makedirs('') would fail.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    destination = os.path.join(save_path, os.path.basename(jpg_path))
    shutil.copy(jpg_path, destination)
    print(f"Copied {jpg_path} to {destination}")
    return [destination]
|
|
|
|
def _load_prediction_map(predict_result):
    """Load a prediction JSON and map each node name to its argmax class index.

    Assumes the file holds parallel ``'name'`` and ``'pred'`` lists, where each
    ``'pred'`` entry is a row of per-class scores. This is inferred from the
    ``argmax(axis=1)`` below; confirm against the producer of the file.
    """
    with open(predict_result, encoding='utf-8') as f:
        result = json.load(f)
    names = np.array(result['name'])
    pred = np.argmax(np.array(result['pred']), axis=1)
    return dict(zip(names, pred))


def post_process(predict_result, test_path, save_path, model_type='norm_sv_line'):
    """Write predicted class labels into annotation JSON files.

    For every ``*.json`` found recursively under ``test_path``:

    1. Load its points with ``load_data``.
    2. Look up each point's predicted class index under the key
       ``"{base_name}_{coord0}_{coord1}..."``.
    3. Replace ``shapes[i]['label']`` with the class name for that index.
    4. Write the result to ``save_path``, mirroring the file's path relative
       to ``test_path``.
    5. Copy the matching ``.jpg`` next to the written JSON.

    Args:
        predict_result: Path to the prediction JSON (``'name'`` / ``'pred'``).
        test_path: Root directory searched recursively for annotation JSONs.
        save_path: Output root directory (created if missing).
        model_type: Selects the class mapping via ``get_class_dict``.

    Raises:
        ValueError: If ``model_type`` is unknown. This is raised before any
            file is processed.
        KeyError: If a point has no entry in the prediction file.
    """
    os.makedirs(save_path, exist_ok=True)
    pred_dict = _load_prediction_map(predict_result)

    # The mapping depends only on model_type, so resolve (and log) it once
    # instead of once per file.
    class_name = get_class_dict(model_type)
    print(f"model_type:{model_type}, class_dict:{class_name}")

    # Sorted for a deterministic processing order.
    json_lst = sorted(glob.glob(os.path.join(test_path, '**', '*.json'), recursive=True))

    for json_path in tqdm(json_lst):
        # split('.')[0] (not splitext) is kept on purpose: the prediction keys
        # were presumably built with the same convention upstream.
        # TODO(review): confirm.
        base_name = os.path.basename(json_path).split('.')[0]
        points, _edge_index, _gt_label, _ = load_data(json_path)
        labels = [
            pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))]
            for point in points
        ]

        with open(json_path, encoding='utf-8') as f:
            annotation = json.load(f)

        # Assumes shapes[i] corresponds to points[i] as returned by load_data.
        # TODO(review): confirm.
        shapes = annotation['shapes']
        for i, class_idx in enumerate(labels):
            shapes[i]['label'] = class_name[class_idx]

        relative_path = os.path.relpath(json_path, test_path)
        target_path = os.path.join(save_path, relative_path)
        target_dir = os.path.dirname(target_path)
        os.makedirs(target_dir, exist_ok=True)

        with open(target_path, 'w', encoding='utf-8') as f:
            json.dump(annotation, f)

        # Keep the image beside its mirrored JSON so nested inputs do not
        # collide in, or get flattened into, the save_path root.
        find_and_copy_image(json_path, target_dir)
|
|
|
|
|
|
|
|
|
|
# predict_result = '/home/deploy/script/bohrim-app/bohrium-app/pythonProject/egnn_v2/logs/0/version_10/test.json'
|
|
# test_path = './gnn_sv_train/test/'
|
|
# save_path = './predictddd_result/post_process/'
|
|
#
|
|
# post_process(predict_result, test_path, save_path)
|
|
|
|
|
|
# with open('/home/deploy/script/bohrim-app/bohrium-app/pythonProject/egnn_v2/logs/0/version_10/test.json') as f:
|
|
# data = json.load(f)
|
|
#
|
|
# name = np.array(data['name'])
|
|
# pred = np.argmax(np.array(data['pred']), axis=1)
|
|
# pred_dict = dict(zip(name, pred))
|
|
#
|
|
# json_lst = glob.glob('./gnn_sv_train/test/*.json', recursive=True); len(json_lst)
|
|
#
|
|
# for json_path in tqdm(json_lst):
|
|
# base_name = json_path.split('/')[-1].split('.')[0]
|
|
# points, edge_index, gt_label, _ = load_data(json_path)
|
|
# labels = np.array([pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))] for point in points])
|
|
#
|
|
# with open(json_path) as f:
|
|
# data = json.load(f)
|
|
#
|
|
# for i in range(len(labels)):
|
|
# data['shapes'][i]['label'] = class_dict[labels[i] + 1]
|
|
#
|
|
# with open(json_path, 'w') as f:
|
|
# json.dump(data, f)
|
|
# #
|
|
# json_lst = glob.glob('/home/gao/mouclear/cc/data_new/gnn_sv/after_test/raw/*.json', recursive=True);
|
|
# len(json_lst)
|
|
# res = []
|
|
#
|
|
# for json_path in tqdm(json_lst):
|
|
# base_name = json_path.split('/')[-1].split('.')[0]
|
|
# points, edge_index, labels, _ = load_data(json_path)
|
|
#
|
|
# mask_pd = np.zeros((2048, 2048))
|
|
# mask_pd[points[:, 0], points[:, 1]] = labels + 1
|
|
# mask_pd = np.array(mask_pd, np.uint8)
|
|
#
|
|
# mask_gt = np.array(Image.open(json_path.replace('.json', '.png')), np.uint8)
|
|
#
|
|
# for i in range(1, 4):
|
|
# res += [get_metrics(mask_gt == i, mask_pd == i)]
|
|
#
|
|
# res = np.array(res)
|
|
#
|
|
#
|
|
# # Norm
|
|
# print(np.mean(res[::3, :], axis=0))
|
|
#
|
|
# # SV
|
|
# print(np.mean(res[1::3, :], axis=0))
|
|
#
|
|
# # LineSV
|
|
# print(np.mean(res[2::3, :], axis=0))
|
|
#
|
|
# print(np.mean(res, axis=0)) |