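# atom-predict/egnn_v2/metricse2e_vor_confusion.py
# (source listing: 88 lines, 2.3 KiB, Python)
#
# For each test image, builds a 3x3 matrix of get_metrics() scores between
# ground-truth class masks and predicted class masks (Norm / SV / LineSV).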

import os
import cv2
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from egnn_utils.e2e_metrics import get_metrics
from egnn_core.data import get_y_3
from egnn_core.data import load_data
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix
# Integer class ids (1-based) and their human-readable class names.
class_dict = {
    1: 'Norm',
    2: 'SV',
    3: 'LineSV',
}

# Inverse lookup (class name -> class id), derived from class_dict so the
# two tables cannot drift apart.
class_dict_rev = {name: class_id for class_id, name in class_dict.items()}
# with open('/cc/data/end-to-end-result-sv/gnn_data/test-on-end-to-end.json') as f:
# data = json.load(f)
#
# name = np.array(data['name'])
# pred = np.argmax(np.array(data['pred']), axis=1)
# pred_dict = dict(zip(name, pred))
#
# json_lst = glob.glob('/cc/data/end-to-end-result-sv/gnn_data/after_test2/*.json', recursive=True); len(json_lst)
#
# for json_path in tqdm(json_lst):
# base_name = json_path.split('/')[-1].split('.')[0]
# points, edge_index, _, _ = load_data(json_path)
# labels = np.array([pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))] for point in points])
#
# with open(json_path) as f:
# data = json.load(f)
#
# for i in range(len(labels)):
# data['shapes'][i]['label'] = class_dict[labels[i] + 1]
#
# with open(json_path, 'w') as f:
# json.dump(data, f)
# Per-image graph JSON files to evaluate. Each one is expected to have a
# ground-truth mask PNG next to it with the same stem.
json_lst = glob.glob('/home/gao/mouclear/cc/final_todo/SV/test/after_test/*.json')

# Class ids compared in the confusion matrix (1-based, same ids as class_dict).
class_ids = sorted(class_dict)
n_classes = len(class_ids)

# One (n_classes x n_classes) matrix per JSON file, in json_lst order.
confusion_list = []
for json_path in tqdm(json_lst):
    points, _edge_index, labels, _ = load_data(json_path)

    # Ground-truth mask: sibling PNG with the same stem. splitext only
    # touches the final extension, unlike str.replace('.json', '.png'),
    # which would also rewrite '.json' appearing in directory names.
    mask_gt = np.array(Image.open(os.path.splitext(json_path)[0] + '.png'), np.uint8)

    # Rasterise predicted point labels onto a mask the same size as the
    # ground truth (previously hard-coded to 2048x2048). labels are shifted
    # by +1 to match the 1-based class ids; 0 stays as "no point".
    # NOTE(review): assumes points are integer (row, col) pixel coordinates
    # in the same frame as the PNG — confirm against load_data.
    mask_pd = np.zeros(mask_gt.shape[:2], np.uint8)
    mask_pd[points[:, 0], points[:, 1]] = labels + 1

    # Entry [i, j] = get_metrics(...)[1] for ground-truth class i vs
    # predicted class j. Zero-initialised so no cell can hold garbage.
    # NOTE(review): which metric index 1 of get_metrics is — TODO confirm.
    confusion = np.zeros((n_classes, n_classes))
    for i, gt_cls in enumerate(class_ids):
        for j, pd_cls in enumerate(class_ids):
            confusion[i, j] = get_metrics(mask_gt == gt_cls, mask_pd == pd_cls)[1]
    confusion_list.append(confusion)

print(confusion_list)
# #
# res = np.array(res)
# #
# #
# # Norm
# print(np.mean(res[::3, :], axis=0))
#
# # SV
# print(np.mean(res[1::3, :], axis=0))
#
# # LineSV
# print(np.mean(res[2::3, :], axis=0))
#
# print(np.mean(res, axis=0))