atom-predict/msunet/MetricsE2E.ipynb

289 lines
12 KiB
Plaintext
Executable File

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "b7f09d7e-dc1e-4962-a031-0a55e5b67a90",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import glob\n",
"import json\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from tqdm import tqdm\n",
"from PIL import Image\n",
"from utils.e2e_metrics import get_metrics\n",
"# get_y_3 is no longer exported by core.data (ImportError on the last run) and is not used anywhere in this notebook.\n",
"# from core.data import get_y_3\n",
"# NOTE(review): a later run (In[6]) also failed to import load_data from core.data - confirm where it now lives.\n",
"from core.data import load_data\n",
"from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix"
]
},
{
"cell_type": "markdown",
"id": "bde4590a-e868-4668-8b88-9b7ae6741c02",
"metadata": {},
"source": [
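"# Update class labels\n",
"\n",
"Overwrite the `label` of every shape in the result JSON files with the class predicted by the EGNN model."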
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7aaf0964-d720-4d61-bf66-4e21d58d8c9c",
"metadata": {},
"outputs": [],
"source": [
"# Class id -> label name. Ids are the model's 0-based argmax + 1, as written into the result JSONs and the predicted masks.\n",
"class_dict = {\n",
"    1: 'Norm',\n",
"    2: 'SV',\n",
"    3: 'LineSV',\n",
"}\n",
"\n",
"# Reverse lookup (label name -> class id), derived from class_dict so the two mappings cannot drift apart.\n",
"class_dict_rev = {label: cls_id for cls_id, label in class_dict.items()}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "41873800-6b74-4fa4-8c8b-3cecd7089518",
"metadata": {},
"outputs": [],
"source": [
"# Per-point predictions from the EGNN run: 'name' holds the point keys, 'pred' one row of per-class values per point (argmax over axis 1 = class index).\n",
"with open('/home/gao/mouclear/cc/code/egnn_jj/logs/0/version_0/e2e.json') as f:\n",
"    e2e_preds = json.load(f)\n",
"\n",
"name = np.array(e2e_preds['name'])\n",
"pred = np.argmax(np.array(e2e_preds['pred']), axis=1)\n",
"# dict(zip(...)) would silently truncate misaligned lists, so check they line up.\n",
"assert len(name) == len(pred), f'{len(name)} names vs {len(pred)} predictions in e2e.json'\n",
"pred_dict = dict(zip(name, pred))\n",
"\n",
"# glob order is filesystem-dependent; sort so runs are reproducible. (Without '**' in the pattern, recursive=True had no effect.)\n",
"json_lst = sorted(glob.glob('/home/gao/mouclear/cc/data/jj/e2e_result/*.json'))\n",
"len(json_lst)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b002854a-b36e-4423-a016-dd088a344681",
"metadata": {},
"outputs": [],
"source": [
"# Write each point's predicted class name into the matching shape of its result JSON (files are overwritten in place).\n",
"# NOTE(review): assumes load_data() returns points in the same order as data['shapes'] - TODO confirm.\n",
"for json_path in tqdm(json_lst):\n",
"    stem = json_path.rsplit('/', 1)[-1].partition('.')[0]\n",
"    points, edge_index, _, _ = load_data(json_path)\n",
"    # Prediction keys are '<stem>_' followed by the point's coordinates joined with '_'.\n",
"    point_keys = ['{}_{}'.format(stem, '_'.join(np.array(pt, np.str_))) for pt in points]\n",
"    labels = np.array([pred_dict[key] for key in point_keys])\n",
"\n",
"    with open(json_path) as f:\n",
"        data = json.load(f)\n",
"\n",
"    shapes = data['shapes']\n",
"    for idx, cls in enumerate(labels):\n",
"        shapes[idx]['label'] = class_dict[cls + 1]\n",
"\n",
"    with open(json_path, 'w') as f:\n",
"        json.dump(data, f)"
]
},
{
"cell_type": "markdown",
"id": "ac107f45-f949-4fcd-8950-9235b23d07ba",
"metadata": {},
"source": [
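"# Metrics\n",
"\n",
"Compare the predicted per-point classes against the ground-truth label masks, per class (Norm, SV, LineSV) and overall."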
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "139d5824-cf33-45d4-8fce-e525245295ea",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ccaf1cb-0202-4f92-9757-ba2eaf6be30a",
"metadata": {},
"outputs": [],
"source": [
"# Rasterise the predicted point classes into a label mask per image and score it against the ground-truth PNG.\n",
"# res gets one row per (image, class) in class order 1, 2, 3 (Norm, SV, LineSV); the cells below rely on that stride of 3.\n",
"json_lst = sorted(glob.glob('/home/gao/mouclear/cc/data/jj/e2e_result/*.json'))\n",
"res = []\n",
"\n",
"for json_path in tqdm(json_lst):\n",
"    points, edge_index, labels, _ = load_data(json_path)\n",
"\n",
"    # The ground-truth label mask sits next to the JSON with the same stem.\n",
"    with Image.open(os.path.splitext(json_path)[0] + '.png') as gt_img:\n",
"        mask_gt = np.array(gt_img, np.uint8)\n",
"\n",
"    # Size the prediction mask from the GT image instead of assuming 2048x2048.\n",
"    # NOTE(review): points are used directly as (row, col) integer indices - TODO confirm against load_data.\n",
"    mask_pd = np.zeros(mask_gt.shape[:2], np.uint8)\n",
"    mask_pd[points[:, 0], points[:, 1]] = labels + 1\n",
"\n",
"    for i in range(1, 4):\n",
"        res += [get_metrics(mask_gt == i, mask_pd == i)]\n",
"\n",
"len(res)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cca8debf-9de3-4473-a7eb-bd3ca8884a2e",
"metadata": {},
"outputs": [],
"source": [
"res = np.array(res)\n",
"# The per-class cells below slice rows with a stride of 3 (classes 1, 2, 3 per image); fail early with a clear message if that layout is missing.\n",
"assert res.ndim == 2 and len(res) % 3 == 0, f'expected a 2-D array with 3 rows per image, got shape {res.shape}'\n",
"res.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5d92bb3b-e3c2-4373-8ae1-2bbdc901476b",
"metadata": {},
"outputs": [],
"source": [
"# Norm (class 1): rows 0, 3, 6, ... of res, averaged per metric column\n",
"print(res[0::3, :].mean(axis=0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "995b9bab-768c-47e2-9199-91fa3572380b",
"metadata": {},
"outputs": [],
"source": [
"# SV (class 2): rows 1, 4, 7, ... of res, averaged per metric column\n",
"print(res[1::3, :].mean(axis=0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34a00da7-8576-4102-a3a8-e78e15db0d45",
"metadata": {},
"outputs": [],
"source": [
"# LineSV (class 3): rows 2, 5, 8, ... of res, averaged per metric column\n",
"print(res[2::3, :].mean(axis=0))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4926aea8-b9d2-4878-b8db-5f676d067bc1",
"metadata": {},
"outputs": [],
"source": [
"# All classes: mean over every row of res (with one row per class per image this equals the mean of the three class means)\n",
"print(res.mean(axis=0))"
]
},
{
"cell_type": "markdown",
"id": "24043672-21c7-40ef-91a1-5c612cfd6c78",
"metadata": {},
"source": [
"## Recorded results\n",
"\n",
"Rows follow the order printed by the cells above (Norm, SV, LineSV, overall mean). Columns are the values returned by `get_metrics`, in its order (presumably precision / recall / F1 - TODO confirm).\n",
"\n",
"**train+test with old v1:**\n",
"\n",
"| Class | Metric 1 | Metric 2 | Metric 3 |\n",
"|---|---|---|---|\n",
"| Norm | 0.994232 | 0.995057 | 0.994643 |\n",
"| SV | 0.863394 | 0.898949 | 0.878925 |\n",
"| LineSV | 0.910954 | 0.906328 | 0.908198 |\n",
"| All | 0.922860 | 0.933445 | 0.927255 |\n",
"\n",
"**train+test with old v2:**\n",
"\n",
"_(no results recorded yet)_\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}