434 lines
12 KiB
Plaintext
Executable File
434 lines
12 KiB
Plaintext
Executable File
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "8710f835-b782-45de-bd4b-9cf0a7d51783",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import cv2\n",
|
|
"import glob\n",
|
|
"import json\n",
|
|
"import scipy\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"from PIL import Image"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "efd0eb6d-67fc-44dc-a4f7-c3e9abc0c6d7",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "c0f0914a-9ae0-4263-a30b-1646c2710975",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Mapping from annotation class name to integer label id.
atom_labels = {
    'Norm': 0,
    'LineSV': 1,
}

# Inverse lookup (label id -> class name), derived from atom_labels so the
# two tables cannot drift apart.
atom_labelsv = {label_id: class_name for class_name, label_id in atom_labels.items()}
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "2bb17487-ad25-4739-9b4d-fa4feddd4422",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def get_gt(json_path, label_map=None):
    """Build a ground-truth mask from a point-annotation JSON file.

    The JSON must contain ``imageHeight``, ``imageWidth`` and a ``shapes``
    list whose items carry ``points`` and ``label`` (this key layout looks
    like a LabelMe export -- confirm). Only the first point of each shape
    is used; points are stored as ``(x, y)`` and reversed to ``(row, col)``.

    Args:
        json_path: Path to the annotation JSON file.
        label_map: Optional dict mapping a shape's ``label`` string to the
            value written into the mask (e.g. ``atom_labels``). When None
            (the default) every annotated point is marked with 1.

    Returns:
        float64 array of shape (imageHeight, imageWidth), zero everywhere
        except at annotated points.

    Raises:
        KeyError: if ``label_map`` is given and lacks a shape's label.
    """
    with open(json_path) as f:
        lbl = json.load(f)

    height, width = lbl['imageHeight'], lbl['imageWidth']
    mask = np.zeros((height, width))

    shapes = lbl['shapes']
    if not shapes:
        # No annotations: indexing an empty point array below would fail.
        return mask

    # (x, y) -> (row, col); intp avoids int16 overflow on large images.
    points = np.round([item['points'][0][::-1] for item in shapes]).astype(np.intp)
    # Rounding can push edge points just outside the image, and negative
    # indices would silently wrap to the opposite edge, so clamp into range.
    rows = np.clip(points[:, 0], 0, height - 1)
    cols = np.clip(points[:, 1], 0, width - 1)

    if label_map is None:
        labels = np.ones(len(shapes))
    else:
        labels = [label_map[item['label']] for item in shapes]

    mask[rows, cols] = labels
    return mask
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d4417c7e-bf73-48e9-a3be-72a1d2965d9e",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Visualization"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "51d980f1-2964-48c9-8876-b750863c85ac",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# name = '15'\n",
|
|
"# pred = cv2.imread('../../v15_Final/data/infer/slide_pred/{}.png'.format(name), 0)\n",
|
|
"# img = cv2.imread('../data/slide/{}.jpg'.format(name), 0)\n",
|
|
"# h, w = np.where(pred != 0)\n",
|
|
"# plt.imshow(img, cmap='gray')\n",
|
|
"# plt.scatter(w, h, 1)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "7369e0d2-fd85-4661-b2aa-65dd369465d3",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Metrics"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "8337e6bf-d151-43fc-9f35-87ffed9107ac",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def center_coords_to_bbox(gt_coord, half_size0=10, half_size1=10):
    """Return a fixed-size box centred on a point.

    Args:
        gt_coord: Sequence whose first two items are the centre coordinates
            ``(c0, c1)``; in this notebook these are ``(row, col)`` pairs
            taken from ``np.where`` on a mask.
        half_size0: Half-extent of the box along the first coordinate.
        half_size1: Half-extent of the box along the second coordinate.

    Returns:
        ``(c0_min, c0_max, c1_min, c1_max)``, the layout ``bbox_iou``
        expects. The upper bounds carry a ``+ 1`` so that, with
        ``bbox_iou``'s ``max - min`` area, the default box is 21 x 21.
    """
    c0, c1 = gt_coord[0], gt_coord[1]
    return (
        c0 - half_size0,
        c0 + half_size0 + 1,
        c1 - half_size1,
        c1 + half_size1 + 1,
    )
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "5dae72c5-9b8d-4461-935e-63690cfb2bd6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def get_coord_to_bboxes(gt_coordinates_dict):
    """Convert centre coordinates into a list of boxes.

    Despite its name, ``gt_coordinates_dict`` is any iterable of coordinate
    pairs; this notebook passes lists of ``(row, col)`` tuples for both the
    ground-truth and the predicted points. Each entry is converted with
    ``center_coords_to_bbox`` and the boxes are returned in input order.
    """
    return [center_coords_to_bbox(coord) for coord in gt_coordinates_dict]
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "f7b75735-4f74-4922-84f0-55b25431dbdc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def bbox_iou(bb1, bb2):
    """Intersection-over-union of two axis-aligned boxes.

    Boxes use the ``(a0_min, a0_max, a1_min, a1_max)`` layout produced by
    ``center_coords_to_bbox``; areas are ``(max - min)`` products.

    Returns:
        IoU as a float in [0, 1]. Returns 0.0 when the boxes do not overlap,
        and also when the union has zero area (both boxes degenerate), which
        previously raised ZeroDivisionError.

    Raises:
        AssertionError: if either box has min > max on an axis.
    """
    assert bb1[0] <= bb1[1]
    assert bb1[2] <= bb1[3]
    assert bb2[0] <= bb2[1]
    assert bb2[2] <= bb2[3]

    # Intersection rectangle ("x" is the first axis, "y" the second).
    x_left = max(bb1[0], bb2[0])
    y_top = max(bb1[2], bb2[2])
    x_right = min(bb1[1], bb2[1])
    y_bottom = min(bb1[3], bb2[3])

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    # The intersection of two axis-aligned boxes is itself an axis-aligned box.
    intersection_area = (x_right - x_left) * (y_bottom - y_top)

    bb1_area = (bb1[1] - bb1[0]) * (bb1[3] - bb1[2])
    bb2_area = (bb2[1] - bb2[0]) * (bb2[3] - bb2[2])

    # Union = sum of both areas minus the double-counted intersection.
    union_area = bb1_area + bb2_area - intersection_area
    if union_area <= 0:
        # Both boxes have zero area; IoU is undefined, treat as no overlap.
        return 0.0

    iou = intersection_area / float(union_area)
    assert 0.0 <= iou <= 1.0
    return iou
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "bfcba26a-f98b-4f91-84a0-3e15866ed682",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def match_bboxes(iou_matrix, IOU_THRESH=0.5):
    """One-to-one match ground-truth and predicted boxes by maximum total IoU.

    Args:
        iou_matrix: 2-D array of shape (n_true, n_pred); entry [i, j] is the
            IoU between ground-truth box i and predicted box j.
        IOU_THRESH: A matched pair is kept only if its IoU is strictly
            greater than this value.

    Returns:
        ``(idx_gt, idx_pred, ious, label)``:
        ``idx_gt`` / ``idx_pred`` are indices of the kept pairs and ``ious``
        their IoU values. ``label`` holds one 0/1 entry for every assignment
        between a real GT box and a real prediction, before thresholding, so
        it can be longer than the other three arrays.
    """
    # Explicit submodule import: a bare ``import scipy`` does not load
    # ``scipy.optimize`` on older SciPy releases.
    from scipy.optimize import linear_sum_assignment

    iou_matrix = np.asarray(iou_matrix)
    n_true, n_pred = iou_matrix.shape
    MIN_IOU = 0.0

    if n_true == 0 or n_pred == 0:
        # Nothing can be matched; avoid padding / assignment on empty input.
        empty_idx = np.array([], dtype=int)
        return empty_idx, empty_idx.copy(), np.array([], dtype=float), empty_idx.copy()

    # Pad to a square matrix with zero-IoU dummy rows/columns; assignments
    # involving padding are discarded below.
    if n_pred > n_true:
        iou_matrix = np.concatenate(
            (iou_matrix, np.full((n_pred - n_true, n_pred), MIN_IOU)), axis=0)
    elif n_true > n_pred:
        iou_matrix = np.concatenate(
            (iou_matrix, np.full((n_true, n_true - n_pred), MIN_IOU)), axis=1)

    # Hungarian matching: minimising (1 - IoU) maximises total IoU.
    idxs_true, idxs_pred = linear_sum_assignment(1 - iou_matrix)

    # Keep only pairs between a real GT row and a real prediction column.
    real = (idxs_true < n_true) & (idxs_pred < n_pred)
    idx_gt_actual = idxs_true[real]
    idx_pred_actual = idxs_pred[real]
    ious_actual = iou_matrix[idx_gt_actual, idx_pred_actual]

    sel_valid = ious_actual > IOU_THRESH
    label = sel_valid.astype(int)

    return idx_gt_actual[sel_valid], idx_pred_actual[sel_valid], ious_actual[sel_valid], label
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "c99310be-0b42-46cb-a173-66d4a738f138",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def eval_matches(gt_bboxes, pd_bboxes, iou_threshold):
    """Score every GT/prediction box pair by IoU and match them one-to-one.

    Args:
        gt_bboxes: Sequence of ground-truth boxes in ``bbox_iou`` layout.
        pd_bboxes: Sequence of predicted boxes in the same layout.
        iou_threshold: Forwarded to ``match_bboxes`` as ``IOU_THRESH``.

    Returns:
        The ``(idxs_true, idxs_pred, ious, labels)`` tuple from ``match_bboxes``.
    """
    iou_matrix = np.zeros((len(gt_bboxes), len(pd_bboxes)), dtype=np.float32)

    # Fill one GT row at a time; values are cast to float32 on assignment.
    for row, gt_box in enumerate(gt_bboxes):
        iou_matrix[row, :] = [bbox_iou(gt_box, pd_box) for pd_box in pd_bboxes]

    return match_bboxes(iou_matrix, IOU_THRESH=iou_threshold)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "11c49f5a-8ec8-4f93-8c25-e4cf58cea7ac",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def eval_metrics(n_matches, n_gt, n_pred):
    """Compute precision and recall from match counts.

    Args:
        n_matches: Number of matched GT/prediction pairs.
        n_gt: Number of ground-truth objects; must be non-zero.
        n_pred: Number of predictions; precision is 0.0 when this is 0.

    Returns:
        ``(precision, recall)``.

    Raises:
        RuntimeError: if ``n_gt`` is 0 (recall would be undefined).
    """
    if n_gt == 0:
        raise RuntimeError("No ground truth atoms???")
    recall = n_matches / n_gt

    if n_pred > 0:
        precision = n_matches / n_pred
    else:
        precision = 0.0

    return precision, recall
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "66f86103-4b56-4d62-a1da-f84931c9f99c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def _mask_to_bboxes(mask):
    """Return ``(coords, bboxes)`` for every non-zero pixel of ``mask``."""
    rows, cols = np.where(mask != 0)
    coords = list(zip(rows, cols))
    return coords, get_coord_to_bboxes(coords)


def get_metrics(gt, pred, iou_threshold):
    """Precision, recall and F1 for one pair of point masks.

    Every non-zero pixel of ``gt`` and of ``pred`` is treated as one object
    centre, expanded into a fixed-size box, and the two box sets are matched
    one-to-one by IoU.
    NOTE(review): a prediction mask containing multi-pixel blobs produces one
    detection per pixel, which would deflate precision -- confirm predictions
    are single-pixel points.

    Returns:
        ``(precision, recall, f1_score)``; ``f1_score`` is 0 when
        precision + recall is 0.
    """
    gt_coords, gt_bboxes = _mask_to_bboxes(gt)
    _, pd_bboxes = _mask_to_bboxes(pred)

    _, idxs_pred, _, _ = eval_matches(gt_bboxes, pd_bboxes, iou_threshold)
    precision, recall = eval_metrics(
        n_matches=len(idxs_pred), n_gt=len(gt_coords), n_pred=len(pd_bboxes))

    denominator = precision + recall
    if denominator > 0:
        f1_score = 2 * (precision * recall) / denominator
    else:
        f1_score = 0

    return precision, recall, f1_score
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "f7443718-9371-43ca-a74f-5fc0632e619f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
def _load_mask(path):
    """Read an image file into a numpy array, closing the file handle."""
    with Image.open(path) as img:
        return np.array(img)


def get_mean_scores(iou_threshold=0.5, names=None, gt_dir=None, pd_dir=None, verbose=False):
    """Average per-image (precision, recall, F1) over a set of mask pairs.

    For each name, ``<gt_dir>/<name>.png`` and ``<pd_dir>/<name>.png`` are
    loaded. Ground truth is binarised with ``> 1`` and predictions with
    ``!= 0`` before scoring with ``get_metrics``.
    NOTE(review): the two thresholds differ, so a ground-truth pixel of value
    1 is ignored -- confirm this matches how the GT PNGs were written.

    Args:
        iou_threshold: IoU a matched pair must exceed.
        names: Base file names without extension; defaults to the notebook
            global ``base_names``.
        gt_dir: Ground-truth directory; defaults to the global ``gt_path``.
        pd_dir: Prediction directory; defaults to the global ``pd_path``.
        verbose: If True, also print the mean scores.

    Returns:
        ndarray ``[precision, recall, f1]`` averaged over images.

    Raises:
        ValueError: if there are no names to evaluate (the mean would
            otherwise silently be NaN).
    """
    names = base_names if names is None else names
    gt_dir = gt_path if gt_dir is None else gt_dir
    pd_dir = pd_path if pd_dir is None else pd_dir

    names = list(names)
    if not names:
        raise ValueError('No images to evaluate: the list of base names is empty.')

    results = []
    for name in names:
        gt = np.array(_load_mask(os.path.join(gt_dir, name + '.png')) > 1, np.uint8)
        pred = np.array(_load_mask(os.path.join(pd_dir, name + '.png')) != 0, np.uint8)
        results.append(get_metrics(gt, pred, iou_threshold))

    scores = np.mean(results, axis=0)
    if verbose:
        print('Precision: {}, Recall: {}, F1_score: {}.'.format(scores[0], scores[1], scores[2]))
    return scores
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a4e11eaa-3e6a-4e8c-95ad-65a991876cad",
|
|
"metadata": {},
|
|
"source": [
|
|
"# TEST"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"id": "f9c4dc7b-3b7b-4ea7-9ca1-6164f61d83f8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# Input directories: ground-truth masks and predicted masks, paired by file name.
gt_path, pd_path = (
    '../data/predsss/gt/',
    '../data/predsss/yolo/',
)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"id": "f74636a9-d2a9-4a8b-a511-13cf47a3f33a",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"3"
|
|
]
|
|
},
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"base_names = [item.split('/')[-1].split('.')[0] for item in glob.glob('../data/infer/slide_pred/*.png')]; len(base_names)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"id": "b86d9715-be7a-48d3-8848-32f04444b62f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([0.05991716, 0.92995877, 0.11253439])"
|
|
]
|
|
},
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
# Mean [precision, recall, F1] over all evaluated images.
IOU_THRESHOLD = 0.5
get_mean_scores(iou_threshold=IOU_THRESHOLD)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "681f731b-261a-4987-bc70-7a7b98170004",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "8c85eee2-6700-44cc-b0d6-9c6a618120d2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "2208e32f-d2c5-4d1a-91ed-1092973e3069",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "9bd2462c-2355-40e9-9de1-5aae784d9865",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "cee3fa20-d898-42cf-8554-3f070838ffff",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1af2ad1c-8b2e-4fb9-a5bc-20311c2ffdf0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "cmae",
|
|
"language": "python",
|
|
"name": "cmae"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|