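"""Evaluate per-class point predictions against ground-truth masks.

For every ``*.json`` prediction file in a fixed test directory, rasterise the
predicted points into a label mask and compare it class-by-class with the
matching ``*.png`` ground-truth mask via ``get_metrics``. Then print the mean
metrics for each class (Norm, SV, LineSV) and overall.
"""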
import glob
import os

import numpy as np
from PIL import Image
from tqdm import tqdm

from egnn_core.data import load_data
from egnn_utils.e2e_metrics import get_metrics


# Directory scanned for prediction files. Each ``<name>.json`` must have a
# ground-truth mask ``<name>.png`` next to it.
DATA_DIR = '/home/gao/mouclear/cc/final_todo/SV/test/after_test'

# Mask class ids 1, 2, 3 (predicted label + 1; 0 is background), in this order.
CLASS_NAMES = ('Norm', 'SV', 'LineSV')

# Sorted so the per-image rows of ``res`` come out in a reproducible order
# (glob's order is filesystem-dependent). The pattern has no ``**``, so the
# former ``recursive=True`` had no effect and is dropped.
json_lst = sorted(glob.glob(os.path.join(DATA_DIR, '*.json')))

# One metrics entry per (image, class), appended class-by-class per image.
res = []

for json_path in tqdm(json_lst):
    # tqdm.write prints without corrupting the progress bar (print does not).
    tqdm.write(json_path)

    # splitext swaps only the final extension; str.replace would also rewrite
    # any '.json' that appears earlier in the path.
    gt_path = os.path.splitext(json_path)[0] + '.png'
    mask_gt = np.array(Image.open(gt_path), np.uint8)

    points, _, labels, _ = load_data(json_path)

    # Rasterise predictions at the ground-truth resolution instead of a
    # hard-coded 2048x2048, so both masks always have matching shapes.
    # Pixel value = label + 1, leaving 0 for background.
    # NOTE(review): assumes the PNG is single-channel and that ``points`` holds
    # integer (row, col) pixel coordinates -- confirm against load_data.
    mask_pd = np.zeros(mask_gt.shape, np.uint8)
    mask_pd[points[:, 0], points[:, 1]] = labels + 1

    for class_id in range(1, len(CLASS_NAMES) + 1):
        metrics = get_metrics(mask_gt == class_id, mask_pd == class_id)
        tqdm.write(str(metrics))
        res.append(metrics)

# Without this guard an empty directory fails later with an opaque IndexError.
if not res:
    raise SystemExit(f'No .json files found in {DATA_DIR}')

res = np.array(res)

# Rows are interleaved per image (img0/Norm, img0/SV, img0/LineSV, img1/Norm,
# ...), so a stride of len(CLASS_NAMES) from each class offset selects that class.
# Plain row slicing also works if get_metrics returns scalars (1-D ``res``).
for offset, name in enumerate(CLASS_NAMES):
    print(f'{name}:', np.mean(res[offset::len(CLASS_NAMES)], axis=0))

print('All:', np.mean(res, axis=0))