# atom-predict/egnn_v2/metric.py
#
# 36 lines
# 961 B
# Python

import glob
import os

import numpy as np
from PIL import Image
from tqdm import tqdm

from egnn_core.data import load_data
from egnn_utils.e2e_metrics import get_metrics
# Prediction JSONs to evaluate; each <stem>.json is expected to have its
# ground-truth label image <stem>.png in the same directory.
JSON_GLOB = '/home/gao/mouclear/cc/final_todo/SV/test/after_test/*.json'

# Class names in class-id order: class id i (1-based, as stored in the masks)
# is CLASS_NAMES[i - 1]. Names come from the original summary comments.
CLASS_NAMES = ('Norm', 'SV', 'LineSV')


def build_pred_mask(points, labels, shape):
    """Rasterise predicted points into a uint8 label mask.

    Args:
        points: array of shape (N, >=2); column 0 is used as the row index and
            column 1 as the column index. Presumably integer pixel coordinates
            from ``load_data`` -- TODO confirm.
        labels: length-N array of class labels; each point is stored as
            ``label + 1`` so that 0 stays reserved for "no prediction".
        shape: shape of the output mask (normally the ground-truth shape).

    Returns:
        A uint8 array of ``shape``, zero except at the given points.
    """
    mask = np.zeros(shape, dtype=np.uint8)
    pts = np.asarray(points)
    if pts.size == 0:
        # No predictions for this image: an all-background mask.
        return mask
    mask[pts[:, 0], pts[:, 1]] = np.asarray(labels) + 1
    return mask


def summarize(res, n_classes=len(CLASS_NAMES)):
    """Average metric rows per class and over everything.

    ``res`` holds one entry per (image, class) pair, appended image by image
    in class order. Rows ``c, c + n_classes, c + 2 * n_classes, ...`` therefore
    all belong to class index ``c``.

    Returns:
        ``(per_class, overall)``. ``per_class`` is a list with the mean row for
        each class index. ``overall`` is the mean over all rows.

    Raises:
        ValueError: if ``res`` is empty.
    """
    res = np.asarray(res)
    if len(res) == 0:
        raise ValueError('no metric rows to summarise')
    per_class = [np.mean(res[c::n_classes], axis=0) for c in range(n_classes)]
    overall = np.mean(res, axis=0)
    return per_class, overall


def main(json_glob=JSON_GLOB):
    """Score every prediction JSON matched by ``json_glob`` against its PNG
    ground truth, then print per-image, per-class and overall metrics."""
    # glob returns files in arbitrary filesystem order; sorting makes the run
    # (and its log) reproducible.
    json_lst = sorted(glob.glob(json_glob))
    if not json_lst:
        print(f'No files matched {json_glob!r}; nothing to evaluate.')
        return

    res = []
    for json_path in tqdm(json_lst):
        # tqdm.write prints without corrupting the progress bar.
        tqdm.write(json_path)
        points, _, labels, _ = load_data(json_path)

        # Swap only the final suffix; str.replace would also rewrite any
        # '.json' appearing in a directory name.
        gt_path = os.path.splitext(json_path)[0] + '.png'
        # Context manager closes the image file handle after reading.
        # NOTE(review): assumes the PNG is a single-channel label image
        # using the same 1..len(CLASS_NAMES) ids -- confirm.
        with Image.open(gt_path) as img:
            mask_gt = np.array(img, np.uint8)

        # Size the prediction mask from the ground truth rather than a
        # hard-coded 2048x2048, so other image sizes work too.
        mask_pd = build_pred_mask(points, labels, mask_gt.shape)

        for class_id in range(1, len(CLASS_NAMES) + 1):
            a = get_metrics(mask_gt == class_id, mask_pd == class_id)
            tqdm.write(str(a))
            res.append(a)

    per_class, overall = summarize(res)
    for name, value in zip(CLASS_NAMES, per_class):
        print(name, value)
    print('All', overall)


if __name__ == '__main__':
    main()