384 lines
8.9 KiB
Plaintext
384 lines
8.9 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 37,
|
|
"id": "b7f09d7e-dc1e-4962-a031-0a55e5b67a90",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T01:59:49.688653Z",
|
|
"start_time": "2024-06-14T01:59:49.685813Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import cv2\n",
|
|
"import glob\n",
|
|
"import json\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"from tqdm import tqdm\n",
|
|
"from PIL import Image\n",
|
|
"from utils.e2e_metrics import get_metrics\n",
|
|
"from core.data import get_y_3\n",
|
|
"from core.data import load_data\n",
|
|
"from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "bde4590a-e868-4668-8b88-9b7ae6741c02",
|
|
"metadata": {},
|
|
"source": [
"# Update class labels with model predictions"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 38,
|
|
"id": "7aaf0964-d720-4d61-bf66-4e21d58d8c9c",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T01:59:49.692414Z",
|
|
"start_time": "2024-06-14T01:59:49.689344Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
"# Class id (1-based) -> label name.\n",
"class_dict = dict(enumerate(('Norm', 'SV', 'LineSV'), start=1))\n",
"\n",
"# Label name -> class id, derived by inverting the mapping above.\n",
"class_dict_rev = {label: class_id for class_id, label in class_dict.items()}"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 39,
|
|
"id": "41873800-6b74-4fa4-8c8b-3cecd7089518",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T01:59:49.752172Z",
|
|
"start_time": "2024-06-14T01:59:49.697959Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "8"
|
|
},
|
|
"execution_count": 39,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# Load the saved test predictions; 'name' and 'pred' are read as parallel sequences.\n",
"with open('./logs/0/version_4/test-on-new.json') as f:\n",
"    data = json.load(f)\n",
"\n",
"name = np.array(data['name'])\n",
"# The column with the highest value in each row is the predicted 0-based class index.\n",
"pred = np.argmax(np.array(data['pred']), axis=1)\n",
"# Lookup table: entry name -> predicted class index.\n",
"pred_dict = dict(zip(name, pred))\n",
"\n",
"# `recursive=True` only has an effect when the pattern contains `**`, so it is dropped;\n",
"# sorting makes the processing order independent of the filesystem.\n",
"json_lst = sorted(glob.glob('../../data/new_v3/gnn_data/e2e/raw/*.json'))\n",
"len(json_lst)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 40,
|
|
"id": "b002854a-b36e-4423-a016-dd088a344681",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T02:00:02.566745Z",
|
|
"start_time": "2024-06-14T01:59:49.752986Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 8/8 [00:12<00:00, 1.60s/it]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"for json_path in tqdm(json_lst):\n",
"    # File name up to its first dot; os.path takes care of the platform's path separator.\n",
"    base_name = os.path.basename(json_path).split('.')[0]\n",
"    points, _, _, _ = load_data(json_path)\n",
"\n",
"    # Predictions are keyed as '<base_name>_<coord>_<coord>...'; a missing key raises KeyError.\n",
"    labels = np.array([\n",
"        pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))]\n",
"        for point in points\n",
"    ])\n",
"\n",
"    # Loop-local name so the predictions JSON held in `data` is not overwritten.\n",
"    with open(json_path) as f:\n",
"        annotation = json.load(f)\n",
"\n",
"    # Predicted indices are 0-based while class_dict is 1-based, hence the + 1.\n",
"    # NOTE(review): assumes annotation['shapes'] is ordered like `points` -- confirm against load_data.\n",
"    shapes = annotation['shapes']\n",
"    for idx, label in enumerate(labels):\n",
"        shapes[idx]['label'] = class_dict[label + 1]\n",
"\n",
"    # Serialise before opening for writing so a failure cannot leave a truncated file behind.\n",
"    serialized = json.dumps(annotation)\n",
"    with open(json_path, 'w') as f:\n",
"        f.write(serialized)"
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ac107f45-f949-4fcd-8950-9235b23d07ba",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Metrics"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 41,
|
|
"id": "139d5824-cf33-45d4-8fce-e525245295ea",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T02:00:02.570140Z",
|
|
"start_time": "2024-06-14T02:00:02.567358Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": "8"
|
|
},
|
|
"execution_count": 41,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# List the annotation files to score. `recursive=True` is dropped because it only matters\n",
"# for patterns containing `**`; sorting keeps the order independent of the filesystem.\n",
"json_lst = sorted(glob.glob('../../data/new_v3/gnn_data/e2e/raw/*.json'))\n",
"len(json_lst)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 42,
|
|
"id": "8ccaf1cb-0202-4f92-9757-ba2eaf6be30a",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T02:01:41.787853Z",
|
|
"start_time": "2024-06-14T02:00:02.570861Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"100%|██████████| 8/8 [01:39<00:00, 12.40s/it]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"res = []\n",
"\n",
"for json_path in tqdm(json_lst):\n",
"    # Only the point coordinates and their labels are needed here.\n",
"    points, _, labels, _ = load_data(json_path)\n",
"\n",
"    # Ground-truth mask stored next to the annotation file (same name, .png extension).\n",
"    mask_gt = np.array(Image.open(json_path.replace('.json', '.png')), np.uint8)\n",
"\n",
"    # Rasterise the labelled points into a mask with the ground truth's height and width\n",
"    # (previously hard-coded to 2048x2048); pixels without a point stay 0.\n",
"    mask_pd = np.zeros(mask_gt.shape[:2], np.uint8)\n",
"    mask_pd[points[:, 0], points[:, 1]] = labels + 1\n",
"\n",
"    # One metrics row per class id 1..3, in that order; the summary cells below rely on\n",
"    # this fixed stride of three rows per file.\n",
"    for class_id in range(1, 4):\n",
"        res.append(get_metrics(mask_gt == class_id, mask_pd == class_id))"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 43,
|
|
"id": "cca8debf-9de3-4473-a7eb-bd3ca8884a2e",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T02:01:41.790606Z",
|
|
"start_time": "2024-06-14T02:01:41.788736Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
"# Stack the collected metric rows into one array so they can be sliced per class.\n",
"res = np.array(res)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 44,
|
|
"id": "5d92bb3b-e3c2-4373-8ae1-2bbdc901476b",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T02:01:41.801101Z",
|
|
"start_time": "2024-06-14T02:01:41.791204Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[0.99696941 0.99299123 0.99497505]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# Norm (class id 1): every third row of res, starting at row 0.\n",
"print(res[0::3].mean(axis=0))"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 45,
|
|
"id": "995b9bab-768c-47e2-9199-91fa3572380b",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T02:01:41.804774Z",
|
|
"start_time": "2024-06-14T02:01:41.801873Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[0.85331018 0.92309508 0.88642627]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# SV (class id 2): every third row of res, starting at row 1.\n",
"print(res[1::3].mean(axis=0))"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 46,
|
|
"id": "34a00da7-8576-4102-a3a8-e78e15db0d45",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T02:01:41.808302Z",
|
|
"start_time": "2024-06-14T02:01:41.805334Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[0.91710062 0.94634315 0.93130157]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# LineSV (class id 3): every third row of res, starting at row 2.\n",
"print(res[2::3].mean(axis=0))"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 47,
|
|
"id": "4926aea8-b9d2-4878-b8db-5f676d067bc1",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T02:01:41.811675Z",
|
|
"start_time": "2024-06-14T02:01:41.809096Z"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[0.92246007 0.95414315 0.93756763]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# Mean over every row of res, i.e. all files and all classes together.\n",
"print(res.mean(axis=0))"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 48,
|
|
"id": "24043672-21c7-40ef-91a1-5c612cfd6c78",
|
|
"metadata": {
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T02:01:41.814829Z",
|
|
"start_time": "2024-06-14T02:01:41.812133Z"
|
|
}
|
|
},
|
|
"outputs": [],
|
|
"source": [
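"# Results of earlier runs, kept for comparison. Each group lists four rows, presumably in\n",
"# the same order as the cells above (Norm, SV, LineSV, overall) -- TODO confirm.\n",
"#old:\n",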
|
|
"#array([0.9938805 , 0.99503497, 0.9944555 ])\n",
|
|
"#array([0.84130692, 0.908723 , 0.87341656])\n",
|
|
"#array([0.91153081, 0.89785063, 0.9044332 ])\n",
|
|
"#array([0.91557274, 0.93386954, 0.92410176])\n",
|
|
"\n",
|
|
"\n",
|
|
"#old_v2:\n",
|
|
"#array([0.998435 , 0.984716, 0.991523 ])\n",
|
|
"#array([0.811561, 0.900894 , 0.853540])\n",
|
|
"#array([0.823770, 0.949781, 0.881809 ])\n",
|
|
"#array([0.877922, 0.945130, 0.908957])\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"#new:\n",
|
|
"#array([0.992068 , 0.997752, 0.994902 ])\n",
|
|
"#array([0.897274, 0.817665 , 0.854897])\n",
|
|
"#array([0.956995, 0.910751, 0.933184 ])\n",
|
|
"#array([0.948779, 0.908723, 0.927661])\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"outputs": [],
|
|
"source": [],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"ExecuteTime": {
|
|
"end_time": "2024-06-14T02:01:41.816394Z",
|
|
"start_time": "2024-06-14T02:01:41.815288Z"
|
|
}
|
|
},
|
|
"id": "7e2470992e693e6e",
|
|
"execution_count": 48
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.18"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|