# atom-predict/msunet/metric.py
# (repository listing metadata: 44 lines, 847 B, Python, executable file)

import glob
import json
import os
from multiprocessing import Pool

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from utils.e2e_metrics import get_metrics
from utils.labelme import get_mask_v2


def get_score(labels=None, preds=None):
    """Score each prediction against its label and return the per-item results.

    Every prediction is converted to a mask with ``get_mask_v2`` and compared
    with the matching label (passed through unchanged) via ``get_metrics``.

    Args:
        labels: Sequence of per-item labels. Defaults to the module-level
            ``label`` array loaded by the script section of this file.
        preds: Sequence of per-item predictions, aligned with ``labels``.
            Defaults to the module-level ``pred`` array.

    Returns:
        A list holding one ``get_metrics`` result per item, in input order.

    Raises:
        ValueError: If ``labels`` and ``preds`` have different lengths, since
            the items could not be paired reliably.
    """
    if labels is None:
        labels = label
    if preds is None:
        preds = pred
    # Refuse to pair misaligned data instead of silently scoring only a prefix.
    if len(labels) != len(preds):
        raise ValueError(
            f"labels and preds differ in length: {len(labels)} != {len(preds)}"
        )
    return [
        get_metrics(lbl, get_mask_v2(prd))
        for lbl, prd in zip(labels, preds)
    ]


def get_score_v(labels=None, preds=None):
    """Score predictions against binarized labels and return per-item results.

    Each label is binarized with ``> 0`` (any positive value becomes True)
    before being compared, via ``get_metrics``, with the mask that
    ``get_mask_v2`` builds from the matching prediction.

    Args:
        labels: Sequence of per-item labels (arrays or nested lists).
            Defaults to the module-level ``label`` array loaded by the script
            section of this file.
        preds: Sequence of per-item predictions, aligned with ``labels``.
            Defaults to the module-level ``pred`` array.

    Returns:
        A list holding one ``get_metrics`` result per item, in input order.

    Raises:
        ValueError: If ``labels`` and ``preds`` have different lengths, since
            the items could not be paired reliably.
    """
    if labels is None:
        labels = label
    if preds is None:
        preds = pred
    # Refuse to pair misaligned data instead of silently scoring only a prefix.
    if len(labels) != len(preds):
        raise ValueError(
            f"labels and preds differ in length: {len(labels)} != {len(preds)}"
        )
    # np.asarray lets plain nested lists be binarized as well as ndarrays;
    # the comparison yields a new boolean array, so the input is not modified.
    return [
        get_metrics(np.asarray(lbl) > 0, get_mask_v2(prd))
        for lbl, prd in zip(labels, preds)
    ]


# Script section: load the dumped test-set labels/predictions, score every
# item with get_score_v, and print the mean of the results across items.
# The default path is the original experiment's log location; set the
# MSUNET_TEST_JSON environment variable to evaluate a different dump.
TEST_JSON = os.environ.get(
    "MSUNET_TEST_JSON",
    "/home/gao/mouclear/cc/code_v2/msunet/logs/0/version_0/test.json",
)
with open(TEST_JSON, encoding="utf-8") as f:
    data = json.load(f)
# Axis 0 is what the scoring helpers iterate: one entry per scored item.
# (The original note labelled this axis "[metric_idx]" — TODO confirm layout.)
label = np.array(data['label'])
pred = np.array(data['pred'])
nums = label.shape[0]  # number of items along axis 0 of `label`
b = get_score_v()
# Presumably each get_metrics result is a fixed-length sequence of metric
# values, so this prints one mean per metric — TODO confirm.
print(np.mean(b, axis=0))