308 lines
15 KiB
Plaintext
308 lines
15 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "b7f09d7e-dc1e-4962-a031-0a55e5b67a90",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T13:09:30.892710Z",
|
|
"start_time": "2024-06-14T13:09:29.658550Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import cv2\n",
|
|
"import glob\n",
|
|
"import json\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"from tqdm import tqdm\n",
|
|
"from PIL import Image\n",
|
|
"from utils.e2e_metrics import get_metrics\n",
|
|
"from core.data import get_y_3\n",
|
|
"from core.data import load_data\n",
|
|
"from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "bde4590a-e868-4668-8b88-9b7ae6741c02",
|
|
"metadata": {},
|
|
"source": [
|
|
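"# Update class labels\n",
"\n",
"Write the predicted class (Norm / SV / LineSV) into the `label` field of every shape in the raw GNN JSON files. **The files are overwritten in place.**"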
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "7aaf0964-d720-4d61-bf66-4e21d58d8c9c",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T13:09:30.895429Z",
|
|
"start_time": "2024-06-14T13:09:30.893492Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Class id -> class name. Ids are 1-based (argmax index + 1); 0 stays background in the masks.\n",
"class_dict = {\n",
"    1: 'Norm',\n",
"    2: 'SV',\n",
"    3: 'LineSV',\n",
"}\n",
"\n",
"# Reverse lookup (class name -> id), derived so the two maps cannot drift apart.\n",
"class_dict_rev = {cls_name: cls_id for cls_id, cls_name in class_dict.items()}"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "41873800-6b74-4fa4-8c8b-3cecd7089518",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T13:09:34.949279Z",
|
|
"start_time": "2024-06-14T13:09:30.895911Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"ename": "KeyError",
|
|
"evalue": "'name'",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
|
|
"\u001B[0;31mKeyError\u001B[0m Traceback (most recent call last)",
|
|
"Cell \u001B[0;32mIn[3], line 4\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mopen\u001B[39m(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m/home/gao/mouclear/cc/data/end-to-end-result/patch_unet/test/end-to-end-test.json\u001B[39m\u001B[38;5;124m'\u001B[39m) \u001B[38;5;28;01mas\u001B[39;00m f:\n\u001B[1;32m 2\u001B[0m data \u001B[38;5;241m=\u001B[39m json\u001B[38;5;241m.\u001B[39mload(f)\n\u001B[0;32m----> 4\u001B[0m name \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39marray(\u001B[43mdata\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[38;5;124;43mname\u001B[39;49m\u001B[38;5;124;43m'\u001B[39;49m\u001B[43m]\u001B[49m)\n\u001B[1;32m 5\u001B[0m pred \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39margmax(np\u001B[38;5;241m.\u001B[39marray(data[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mpred\u001B[39m\u001B[38;5;124m'\u001B[39m]), axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1\u001B[39m)\n\u001B[1;32m 6\u001B[0m pred_dict \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mdict\u001B[39m(\u001B[38;5;28mzip\u001B[39m(name, pred))\n",
|
|
"\u001B[0;31mKeyError\u001B[0m: 'name'"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# End-to-end patch-UNet predictions: expected to hold parallel 'name' and 'pred' lists.\n",
"PRED_JSON = '/home/gao/mouclear/cc/data/end-to-end-result/patch_unet/test/end-to-end-test.json'\n",
"# GNN input files whose shape labels get rewritten below.\n",
"RAW_GLOB = '../../data/end-to-end-result/gnn_data/e2e/raw/*.json'\n",
"\n",
"with open(PRED_JSON) as f:\n",
"    pred_data = json.load(f)\n",
"\n",
"# Fail with an actionable message (instead of a bare KeyError) if the schema differs.\n",
"missing_keys = {'name', 'pred'} - set(pred_data)\n",
"assert not missing_keys, f'{PRED_JSON} lacks keys {sorted(missing_keys)}; available: {sorted(pred_data)}'\n",
"\n",
"name = np.array(pred_data['name'])\n",
"# argmax over axis 1 -> 0-based class index (assumes pred is (n_samples, n_classes) - TODO confirm).\n",
"pred = np.argmax(np.array(pred_data['pred']), axis=1)\n",
"# zip() would silently truncate on a length mismatch.\n",
"assert len(name) == len(pred), f'{len(name)} names vs {len(pred)} predictions'\n",
"pred_dict = dict(zip(name, pred))\n",
"\n",
"# sorted() gives a deterministic processing order (glob order is arbitrary).\n",
"json_lst = sorted(glob.glob(RAW_GLOB)); len(json_lst)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "b002854a-b36e-4423-a016-dd088a344681",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Rewrite each shape's 'label' with its predicted class name. NOTE: overwrites the input JSON files.\n",
"for json_path in tqdm(json_lst):\n",
"    # Stem = file name up to the first '.', matching how prediction keys are built.\n",
"    base_name = os.path.basename(json_path).split('.')[0]\n",
"    points, _, _, _ = load_data(json_path)\n",
"\n",
"    # Prediction key format: '<stem>_<coord0>_<coord1>...'.\n",
"    keys = ['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_))) for point in points]\n",
"    missing = [k for k in keys if k not in pred_dict]\n",
"    assert not missing, f'{json_path}: {len(missing)} points without a prediction, e.g. {missing[:3]}'\n",
"    labels = np.array([pred_dict[k] for k in keys])\n",
"\n",
"    with open(json_path) as f:\n",
"        data = json.load(f)\n",
"\n",
"    # Labels are written positionally, so shapes must line up 1:1 with the loaded points.\n",
"    shapes = data['shapes']\n",
"    assert len(shapes) == len(labels), f'{json_path}: {len(shapes)} shapes vs {len(labels)} points'\n",
"    for shape, label in zip(shapes, labels):\n",
"        # 0-based predicted index -> 1-based class id -> class name.\n",
"        shape['label'] = class_dict[int(label) + 1]\n",
"\n",
"    with open(json_path, 'w') as f:\n",
"        json.dump(data, f)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ac107f45-f949-4fcd-8950-9235b23d07ba",
|
|
"metadata": {},
|
|
"source": [
|
|
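"# Metrics\n",
"\n",
"Per-class metrics (Norm, SV, LineSV) comparing the rasterised predicted point labels with the ground-truth PNG masks, averaged over images."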
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "139d5824-cf33-45d4-8fce-e525245295ea",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "8ccaf1cb-0202-4f92-9757-ba2eaf6be30a",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-19T15:38:31.790346Z",
|
|
"start_time": "2024-06-19T15:37:50.968183Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
" 0%| | 0/4 [00:37<?, ?it/s]\n"
|
|
]
|
|
},
|
|
{
|
|
"ename": "KeyboardInterrupt",
|
|
"evalue": "",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
|
|
"\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
|
|
"Cell \u001B[0;32mIn[5], line 29\u001B[0m\n\u001B[1;32m 25\u001B[0m mask_pd \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39marray(mask_pd, np\u001B[38;5;241m.\u001B[39muint8)\n\u001B[1;32m 27\u001B[0m mask_gt \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39marray(Image\u001B[38;5;241m.\u001B[39mopen(json_path\u001B[38;5;241m.\u001B[39mreplace(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m.json\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124m.png\u001B[39m\u001B[38;5;124m'\u001B[39m)), np\u001B[38;5;241m.\u001B[39muint8)\n\u001B[0;32m---> 29\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28;43mrange\u001B[39;49m(\u001B[38;5;241m1\u001B[39m, \u001B[38;5;241m4\u001B[39m):\n\u001B[1;32m 30\u001B[0m res \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m [get_metrics(mask_gt \u001B[38;5;241m==\u001B[39m i, mask_pd \u001B[38;5;241m==\u001B[39m i)]\n",
|
|
"Cell \u001B[0;32mIn[5], line 29\u001B[0m\n\u001B[1;32m 25\u001B[0m mask_pd \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39marray(mask_pd, np\u001B[38;5;241m.\u001B[39muint8)\n\u001B[1;32m 27\u001B[0m mask_gt \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39marray(Image\u001B[38;5;241m.\u001B[39mopen(json_path\u001B[38;5;241m.\u001B[39mreplace(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m.json\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;124m'\u001B[39m\u001B[38;5;124m.png\u001B[39m\u001B[38;5;124m'\u001B[39m)), np\u001B[38;5;241m.\u001B[39muint8)\n\u001B[0;32m---> 29\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m i \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28;43mrange\u001B[39;49m(\u001B[38;5;241m1\u001B[39m, \u001B[38;5;241m4\u001B[39m):\n\u001B[1;32m 30\u001B[0m res \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m [get_metrics(mask_gt \u001B[38;5;241m==\u001B[39m i, mask_pd \u001B[38;5;241m==\u001B[39m i)]\n",
|
|
"File \u001B[0;32m~/下载/pycharm-professional-2023.3.5/pycharm-2023.3.5/plugins/python/helpers/pydev/_pydevd_bundle/pydevd_frame.py:755\u001B[0m, in \u001B[0;36mPyDBFrame.trace_dispatch\u001B[0;34m(self, frame, event, arg)\u001B[0m\n\u001B[1;32m 753\u001B[0m \u001B[38;5;66;03m# if thread has a suspend flag, we suspend with a busy wait\u001B[39;00m\n\u001B[1;32m 754\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m info\u001B[38;5;241m.\u001B[39mpydev_state \u001B[38;5;241m==\u001B[39m STATE_SUSPEND:\n\u001B[0;32m--> 755\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdo_wait_suspend\u001B[49m\u001B[43m(\u001B[49m\u001B[43mthread\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mframe\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mevent\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43marg\u001B[49m\u001B[43m)\u001B[49m\n\u001B[1;32m 756\u001B[0m \u001B[38;5;66;03m# No need to reset frame.f_trace to keep the same trace function.\u001B[39;00m\n\u001B[1;32m 757\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mtrace_dispatch\n",
|
|
"File \u001B[0;32m~/下载/pycharm-professional-2023.3.5/pycharm-2023.3.5/plugins/python/helpers/pydev/_pydevd_bundle/pydevd_frame.py:412\u001B[0m, in \u001B[0;36mPyDBFrame.do_wait_suspend\u001B[0;34m(self, *args, **kwargs)\u001B[0m\n\u001B[1;32m 411\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mdo_wait_suspend\u001B[39m(\u001B[38;5;28mself\u001B[39m, \u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 412\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_args\u001B[49m\u001B[43m[\u001B[49m\u001B[38;5;241;43m0\u001B[39;49m\u001B[43m]\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mdo_wait_suspend\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43margs\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
|
|
"File \u001B[0;32m~/下载/pycharm-professional-2023.3.5/pycharm-2023.3.5/plugins/python/helpers/pydev/pydevd.py:1184\u001B[0m, in \u001B[0;36mPyDB.do_wait_suspend\u001B[0;34m(self, thread, frame, event, arg, send_suspend_message, is_unhandled_exception)\u001B[0m\n\u001B[1;32m 1181\u001B[0m from_this_thread\u001B[38;5;241m.\u001B[39mappend(frame_id)\n\u001B[1;32m 1183\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_threads_suspended_single_notification\u001B[38;5;241m.\u001B[39mnotify_thread_suspended(thread_id, stop_reason):\n\u001B[0;32m-> 1184\u001B[0m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_do_wait_suspend\u001B[49m\u001B[43m(\u001B[49m\u001B[43mthread\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mframe\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mevent\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43marg\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43msuspend_type\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mfrom_this_thread\u001B[49m\u001B[43m)\u001B[49m\n",
|
|
"File \u001B[0;32m~/下载/pycharm-professional-2023.3.5/pycharm-2023.3.5/plugins/python/helpers/pydev/pydevd.py:1199\u001B[0m, in \u001B[0;36mPyDB._do_wait_suspend\u001B[0;34m(self, thread, frame, event, arg, suspend_type, from_this_thread)\u001B[0m\n\u001B[1;32m 1196\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_call_mpl_hook()\n\u001B[1;32m 1198\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mprocess_internal_commands()\n\u001B[0;32m-> 1199\u001B[0m \u001B[43mtime\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43msleep\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;241;43m0.01\u001B[39;49m\u001B[43m)\u001B[49m\n\u001B[1;32m 1201\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mcancel_async_evaluation(get_current_thread_id(thread), \u001B[38;5;28mstr\u001B[39m(\u001B[38;5;28mid\u001B[39m(frame)))\n\u001B[1;32m 1203\u001B[0m \u001B[38;5;66;03m# process any stepping instructions\u001B[39;00m\n",
|
|
"\u001B[0;31mKeyboardInterrupt\u001B[0m: "
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Score GNN-labelled point JSONs against the ground-truth PNG masks stored next to them.\n",
"# (Imports come from the first cell.)\n",
"EVAL_GLOB = '/home/gao/mouclear/cc/data/end-to-end-result/gnn_data/after_test/*.json'\n",
"# Separate name so re-running the update cell cannot pick up these files via json_lst.\n",
"eval_json_lst = sorted(glob.glob(EVAL_GLOB))\n",
"\n",
"# One metrics row per (image, class); rows cycle through class ids 1, 2, 3 (Norm, SV, LineSV).\n",
"res = []\n",
"\n",
"for json_path in tqdm(eval_json_lst):\n",
"    points, _, labels, _ = load_data(json_path)\n",
"\n",
"    mask_gt = np.array(Image.open(os.path.splitext(json_path)[0] + '.png'), np.uint8)\n",
"    assert mask_gt.ndim == 2, f'{json_path}: expected a single-channel mask, got shape {mask_gt.shape}'\n",
"\n",
"    # Rasterise predicted labels on a canvas matching the GT mask (was hard-coded 2048x2048).\n",
"    mask_pd = np.zeros(mask_gt.shape, np.uint8)\n",
"    mask_pd[points[:, 0], points[:, 1]] = labels + 1\n",
"\n",
"    for i in range(1, 4):\n",
"        res += [get_metrics(mask_gt == i, mask_pd == i)]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "cca8debf-9de3-4473-a7eb-bd3ca8884a2e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Stack the per-(image, class) metric rows into a 2-D array for the slicing below.\n",
"res = np.asarray(res)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "5d92bb3b-e3c2-4373-8ae1-2bbdc901476b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Norm (class id 1): rows 0, 3, 6, ...\n",
"print(res[0::3].mean(axis=0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "995b9bab-768c-47e2-9199-91fa3572380b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# SV (class id 2): rows 1, 4, 7, ...\n",
"print(res[1::3].mean(axis=0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "34a00da7-8576-4102-a3a8-e78e15db0d45",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# LineSV (class id 3): rows 2, 5, 8, ...\n",
"print(res[2::3].mean(axis=0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4926aea8-b9d2-4878-b8db-5f676d067bc1",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Overall: mean over every (image, class) row.\n",
"print(res.mean(axis=0))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
"id": "24043672-21c7-40ef-91a1-5c612cfd6c78",
"metadata": {},
"source": [
"### Recorded results: train+test with old v1\n",
"\n",
"Rows follow the order of the Norm / SV / LineSV / overall cells above. Each overall value equals the mean of the three class values. Columns are the three `get_metrics` outputs in return order (TODO: confirm their names).\n",
"\n",
"| Class | metric 1 | metric 2 | metric 3 |\n",
"|---|---|---|---|\n",
"| Norm | 0.994232 | 0.995057 | 0.994643 |\n",
"| SV | 0.863394 | 0.898949 | 0.878925 |\n",
"| LineSV | 0.910954 | 0.906328 | 0.908198 |\n",
"| Overall | 0.922860 | 0.933445 | 0.927255 |\n",
"\n",
"### Recorded results: train+test with old v2\n",
"\n",
"TODO: not recorded yet.\n"
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.18"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|