atom-predict/msunet/.ipynb_checkpoints/E2EMetricsRecall-checkpoint...

434 lines
12 KiB
Plaintext
Executable File

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "8710f835-b782-45de-bd4b-9cf0a7d51783",
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import json\n",
"import os\n",
"\n",
"import cv2\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import scipy\n",
"import scipy.optimize\n",
"from PIL import Image"
]
},
{
"cell_type": "markdown",
"id": "efd0eb6d-67fc-44dc-a4f7-c3e9abc0c6d7",
"metadata": {},
"source": [
"# Data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c0f0914a-9ae0-4263-a30b-1646c2710975",
"metadata": {},
"outputs": [],
"source": [
"# Class name -> integer id.\n",
"atom_labels = {'Norm': 0, 'LineSV': 1}\n",
"\n",
"# Reverse lookup (integer id -> class name), derived from atom_labels.\n",
"atom_labelsv = {class_id: class_name for class_name, class_id in atom_labels.items()}"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2bb17487-ad25-4739-9b4d-fa4feddd4422",
"metadata": {},
"outputs": [],
"source": [
"def get_gt(json_path):\n",
"    \"\"\"Rasterise the point annotations of a label JSON file into a mask.\n",
"\n",
"    The JSON must provide 'imageHeight', 'imageWidth' and a 'shapes' list whose\n",
"    items each carry a 'points' list. Only the first point of every shape is\n",
"    used and it is read as (x, y) -- presumably the LabelMe layout; confirm.\n",
"\n",
"    Args:\n",
"        json_path: path of the annotation file.\n",
"\n",
"    Returns:\n",
"        Float array of shape (imageHeight, imageWidth) holding 1 at every\n",
"        annotated pixel and 0 elsewhere. The class label of a shape is ignored.\n",
"    \"\"\"\n",
"    with open(json_path) as f:\n",
"        lbl = json.load(f)\n",
"\n",
"    height, width = lbl['imageHeight'], lbl['imageWidth']\n",
"    mask = np.zeros((height, width))\n",
"\n",
"    shapes = lbl['shapes']\n",
"    if not shapes:\n",
"        # Nothing annotated: a 1-D empty array cannot be indexed with [:, 0].\n",
"        return mask\n",
"\n",
"    # Reverse each (x, y) pair into (row, col). int64 instead of int16 so that\n",
"    # coordinates above 32767 do not overflow.\n",
"    points = np.round([item['points'][0][::-1] for item in shapes]).astype(np.int64)\n",
"\n",
"    # Rounding can land exactly on the image size, and a negative index would\n",
"    # silently wrap around, so clamp both axes to the valid pixel range.\n",
"    rows = np.clip(points[:, 0], 0, height - 1)\n",
"    cols = np.clip(points[:, 1], 0, width - 1)\n",
"\n",
"    mask[rows, cols] = 1\n",
"\n",
"    return mask"
]
},
{
"cell_type": "markdown",
"id": "d4417c7e-bf73-48e9-a3be-72a1d2965d9e",
"metadata": {},
"source": [
"# Visualization"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "51d980f1-2964-48c9-8876-b750863c85ac",
"metadata": {},
"outputs": [],
"source": [
"# name = '15'\n",
"# pred = cv2.imread('../../v15_Final/data/infer/slide_pred/{}.png'.format(name), 0)\n",
"# img = cv2.imread('../data/slide/{}.jpg'.format(name), 0)\n",
"# h, w = np.where(pred != 0)\n",
"# plt.imshow(img, cmap='gray')\n",
"# plt.scatter(w, h, 1)"
]
},
{
"cell_type": "markdown",
"id": "7369e0d2-fd85-4661-b2aa-65dd369465d3",
"metadata": {},
"source": [
"# Metrics"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8337e6bf-d151-43fc-9f35-87ffed9107ac",
"metadata": {},
"outputs": [],
"source": [
"def center_coords_to_bbox(gt_coord, box_rwidth=10, box_rheight=10):\n",
"    \"\"\"Expand a centre point into a box around it.\n",
"\n",
"    Args:\n",
"        gt_coord: indexable pair; element 0 and element 1 are the centre.\n",
"        box_rwidth: half-extent applied to gt_coord[0] (default 10).\n",
"        box_rheight: half-extent applied to gt_coord[1] (default 10).\n",
"\n",
"    Returns:\n",
"        Tuple (c0 - box_rwidth, c0 + box_rwidth + 1,\n",
"               c1 - box_rheight, c1 + box_rheight + 1),\n",
"        i.e. each axis spans 2 * half-extent + 1 units.\n",
"    \"\"\"\n",
"    c0, c1 = gt_coord[0], gt_coord[1]\n",
"    return (\n",
"        c0 - box_rwidth,\n",
"        c0 + box_rwidth + 1,\n",
"        c1 - box_rheight,\n",
"        c1 + box_rheight + 1,\n",
"    )"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "5dae72c5-9b8d-4461-935e-63690cfb2bd6",
"metadata": {},
"outputs": [],
"source": [
"def get_coord_to_bboxes(gt_coordinates_dict):\n",
"    \"\"\"Convert every centre point of an iterable into its bounding box.\n",
"\n",
"    Each item is passed to center_coords_to_bbox; the boxes are returned as a\n",
"    list in iteration order.\n",
"    \"\"\"\n",
"    return [center_coords_to_bbox(center) for center in gt_coordinates_dict]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "f7b75735-4f74-4922-84f0-55b25431dbdc",
"metadata": {},
"outputs": [],
"source": [
"def bbox_iou(bb1, bb2):\n",
"    \"\"\"Intersection-over-union of two axis-aligned boxes.\n",
"\n",
"    Args:\n",
"        bb1, bb2: boxes laid out as (a_min, a_max, b_min, b_max), where a and b\n",
"            are the two axes. Each min must not exceed its max.\n",
"\n",
"    Returns:\n",
"        IoU in [0, 1]. 0.0 is returned when the boxes do not overlap or when\n",
"        the union has no area (both boxes degenerate).\n",
"\n",
"    Raises:\n",
"        AssertionError: if a box has min > max on either axis.\n",
"    \"\"\"\n",
"    assert bb1[0] <= bb1[1]\n",
"    assert bb1[2] <= bb1[3]\n",
"    assert bb2[0] <= bb2[1]\n",
"    assert bb2[2] <= bb2[3]\n",
"\n",
"    # determine the coordinates of the intersection rectangle\n",
"    x_left = max(bb1[0], bb2[0])\n",
"    y_top = max(bb1[2], bb2[2])\n",
"    x_right = min(bb1[1], bb2[1])\n",
"    y_bottom = min(bb1[3], bb2[3])\n",
"\n",
"    if x_right < x_left or y_bottom < y_top:\n",
"        return 0.0\n",
"\n",
"    # The intersection of two axis-aligned bounding boxes is always an\n",
"    # axis-aligned bounding box\n",
"    intersection_area = (x_right - x_left) * (y_bottom - y_top)\n",
"\n",
"    # compute the area of both AABBs\n",
"    bb1_area = (bb1[1] - bb1[0]) * (bb1[3] - bb1[2])\n",
"    bb2_area = (bb2[1] - bb2[0]) * (bb2[3] - bb2[2])\n",
"\n",
"    # union = prediction area + ground-truth area - intersection area\n",
"    union_area = bb1_area + bb2_area - intersection_area\n",
"    if union_area <= 0:\n",
"        # Two coincident zero-area boxes would otherwise divide by zero.\n",
"        return 0.0\n",
"\n",
"    iou = intersection_area / float(union_area)\n",
"    assert 0.0 <= iou <= 1.0\n",
"    return iou"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "bfcba26a-f98b-4f91-84a0-3e15866ed682",
"metadata": {},
"outputs": [],
"source": [
"def match_bboxes(iou_matrix, IOU_THRESH=0.5):\n",
"    \"\"\"One-to-one matching of ground-truth and predicted boxes by IoU.\n",
"\n",
"    The matrix is padded with zero-IoU dummy rows or columns until it is\n",
"    square, then the assignment minimising the total (1 - IoU) cost is solved\n",
"    with scipy.optimize.linear_sum_assignment.\n",
"\n",
"    Args:\n",
"        iou_matrix: 2-D array, rows = ground truth, columns = predictions.\n",
"        IOU_THRESH: a pair counts as a match only if its IoU is strictly\n",
"            greater than this value.\n",
"\n",
"    Returns:\n",
"        Tuple (idx_gt, idx_pred, ious, label):\n",
"        - idx_gt, idx_pred, ious: row index, column index and IoU of every\n",
"          assigned pair whose IoU exceeds IOU_THRESH.\n",
"        - label: 0/1 flag for every assignment whose column is a real\n",
"          prediction, taken BEFORE thresholding, so it can be longer than the\n",
"          other three arrays (1 = above threshold, 0 = not).\n",
"    \"\"\"\n",
"    n_true, n_pred = iou_matrix.shape\n",
"    MIN_IOU = 0.0\n",
"\n",
"    if n_pred > n_true:\n",
"        # there are more predictions than ground-truth - add dummy rows\n",
"        diff = n_pred - n_true\n",
"        iou_matrix = np.concatenate((iou_matrix,\n",
"                                     np.full((diff, n_pred), MIN_IOU)),\n",
"                                    axis=0)\n",
"\n",
"    if n_true > n_pred:\n",
"        # more ground-truth than predictions - add dummy columns\n",
"        diff = n_true - n_pred\n",
"        iou_matrix = np.concatenate((iou_matrix,\n",
"                                     np.full((n_true, diff), MIN_IOU)),\n",
"                                    axis=1)\n",
"\n",
"    # call the Hungarian matching\n",
"    idxs_true, idxs_pred = scipy.optimize.linear_sum_assignment(1 - iou_matrix)\n",
"\n",
"    # remove assignments that point at a dummy prediction column\n",
"    sel_pred = idxs_pred < n_pred\n",
"    idx_pred_actual = idxs_pred[sel_pred]\n",
"    idx_gt_actual = idxs_true[sel_pred]\n",
"    ious_actual = iou_matrix[idx_gt_actual, idx_pred_actual]\n",
"    sel_valid = (ious_actual > IOU_THRESH)\n",
"    label = sel_valid.astype(int)\n",
"\n",
"    return idx_gt_actual[sel_valid], idx_pred_actual[sel_valid], ious_actual[sel_valid], label"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c99310be-0b42-46cb-a173-66d4a738f138",
"metadata": {},
"outputs": [],
"source": [
"def eval_matches(gt_bboxes, pd_bboxes, iou_threshold):\n",
"    \"\"\"Build the pairwise IoU table and match ground truth to predictions.\n",
"\n",
"    Row i / column j of the float32 table holds bbox_iou(gt_bboxes[i],\n",
"    pd_bboxes[j]). The table is handed to match_bboxes with\n",
"    IOU_THRESH=iou_threshold and that function's four outputs are returned\n",
"    unchanged.\n",
"    \"\"\"\n",
"    iou_matrix = np.zeros((len(gt_bboxes), len(pd_bboxes)), dtype=np.float32)\n",
"\n",
"    # Fill one ground-truth row at a time.\n",
"    for gt_idx, gt_bbox in enumerate(gt_bboxes):\n",
"        iou_matrix[gt_idx, :] = [bbox_iou(gt_bbox, pd_bbox) for pd_bbox in pd_bboxes]\n",
"\n",
"    matched_gt, matched_pd, matched_ious, flags = match_bboxes(iou_matrix, IOU_THRESH=iou_threshold)\n",
"    return matched_gt, matched_pd, matched_ious, flags"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "11c49f5a-8ec8-4f93-8c25-e4cf58cea7ac",
"metadata": {},
"outputs": [],
"source": [
"def eval_metrics(n_matches, n_gt, n_pred):\n",
"    \"\"\"Precision and recall from match counts.\n",
"\n",
"    Args:\n",
"        n_matches: number of matched pairs.\n",
"        n_gt: number of ground-truth items; must be non-zero.\n",
"        n_pred: number of predictions; precision is 0.0 when this is 0.\n",
"\n",
"    Returns:\n",
"        Tuple (precision, recall).\n",
"\n",
"    Raises:\n",
"        RuntimeError: if n_gt is 0.\n",
"    \"\"\"\n",
"    if n_gt == 0:\n",
"        raise RuntimeError(\"No ground truth atoms???\")\n",
"\n",
"    if n_pred > 0:\n",
"        precision = n_matches / n_pred\n",
"    else:\n",
"        precision = 0.0\n",
"\n",
"    return precision, n_matches / n_gt"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "66f86103-4b56-4d62-a1da-f84931c9f99c",
"metadata": {},
"outputs": [],
"source": [
"def get_metrics(gt, pred, iou_threshold):\n",
"    \"\"\"Precision, recall and F1 between two 2-D point masks.\n",
"\n",
"    Every non-zero pixel of each mask is turned into a box with\n",
"    get_coord_to_bboxes; boxes are matched with eval_matches and the counts\n",
"    are scored with eval_metrics. F1 is the integer 0 when precision + recall\n",
"    is not positive.\n",
"    \"\"\"\n",
"\n",
"    def mask_to_boxes(mask):\n",
"        # (row, col) of every non-zero pixel, plus the box around each one\n",
"        rows, cols = np.where(mask != 0)\n",
"        coords = list(zip(rows.flatten(), cols.flatten()))\n",
"        return coords, get_coord_to_bboxes(coords)\n",
"\n",
"    gt_coords, gt_bboxes = mask_to_boxes(gt)\n",
"    _, pd_bboxes = mask_to_boxes(pred)\n",
"\n",
"    _, idxs_pred, _, _ = eval_matches(gt_bboxes, pd_bboxes, iou_threshold)\n",
"    precision, recall = eval_metrics(n_matches=len(idxs_pred), n_gt=len(gt_coords), n_pred=len(pd_bboxes))\n",
"\n",
"    if precision + recall > 0:\n",
"        f1_score = 2 * (precision * recall) / (precision + recall)\n",
"    else:\n",
"        f1_score = 0\n",
"\n",
"    return precision, recall, f1_score"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f7443718-9371-43ca-a74f-5fc0632e619f",
"metadata": {},
"outputs": [],
"source": [
"def get_mean_scores(iou_threshold=0.5):\n",
"    \"\"\"Average precision, recall and F1 over all slides.\n",
"\n",
"    Reads the notebook globals base_names, gt_path and pd_path: for every name\n",
"    in base_names, '<name>.png' is opened from both directories, binarised and\n",
"    scored with get_metrics.\n",
"\n",
"    Args:\n",
"        iou_threshold: IoU threshold forwarded to get_metrics.\n",
"\n",
"    Returns:\n",
"        np.mean over the per-slide (precision, recall, f1) tuples, axis 0.\n",
"    \"\"\"\n",
"    results = []\n",
"\n",
"    for base_name in base_names:\n",
"        gt = np.array(Image.open(os.path.join(gt_path, base_name + '.png')))\n",
"        pred = np.array(Image.open(os.path.join(pd_path, base_name + '.png')))\n",
"\n",
"        # NOTE(review): ground truth keeps only pixel values > 1 while the\n",
"        # prediction keeps every non-zero pixel -- confirm this is intended.\n",
"        gt = np.array(gt > 1, np.uint8)\n",
"        pred = np.array(pred != 0, np.uint8)\n",
"\n",
"        results.append(get_metrics(gt, pred, iou_threshold))\n",
"\n",
"    # The former print() after the return was unreachable and referenced an\n",
"    # undefined name, so it has been dropped; the caller displays the result.\n",
"    return np.mean(results, axis=0)"
]
},
{
"cell_type": "markdown",
"id": "a4e11eaa-3e6a-4e8c-95ad-65a991876cad",
"metadata": {},
"source": [
"# TEST"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "f9c4dc7b-3b7b-4ea7-9ca1-6164f61d83f8",
"metadata": {},
"outputs": [],
"source": [
"# Directories read by get_mean_scores: ground-truth PNGs and predicted PNGs.\n",
"gt_path, pd_path = '../data/predsss/gt/', '../data/predsss/yolo/'"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "f74636a9-d2a9-4a8b-a511-13cf47a3f33a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Slide names = file stems of the PNGs in the prediction folder.\n",
"# os.path keeps stems that contain dots intact and works with any path\n",
"# separator; sorting makes the order independent of the file system.\n",
"# NOTE(review): this folder differs from gt_path / pd_path that get_mean_scores\n",
"# reads from -- confirm the same file names exist in all three places.\n",
"base_names = sorted(\n",
"    os.path.splitext(os.path.basename(item))[0]\n",
"    for item in glob.glob('../data/infer/slide_pred/*.png')\n",
")\n",
"len(base_names)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "b86d9715-be7a-48d3-8848-32f04444b62f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([0.05991716, 0.92995877, 0.11253439])"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Mean [precision, recall, F1] over all names in base_names at IoU threshold 0.5.\n",
"get_mean_scores(iou_threshold=0.5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "681f731b-261a-4987-bc70-7a7b98170004",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c85eee2-6700-44cc-b0d6-9c6a618120d2",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "2208e32f-d2c5-4d1a-91ed-1092973e3069",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "9bd2462c-2355-40e9-9de1-5aae784d9865",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "cee3fa20-d898-42cf-8554-3f070838ffff",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "1af2ad1c-8b2e-4fb9-a5bc-20311c2ffdf0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "cmae",
"language": "python",
"name": "cmae"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}