atom-predict/egnn_v2/MetricsE2E_vor.ipynb

322 lines
8.2 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "b7f09d7e-dc1e-4962-a031-0a55e5b67a90",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-22T13:10:57.446546Z",
"start_time": "2024-06-22T13:10:56.276676Z"
}
},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import glob\n",
"import json\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from tqdm import tqdm\n",
"from PIL import Image\n",
"from utils.e2e_metrics import get_metrics\n",
"from core.data import get_y_3\n",
"from core.data import load_data\n",
"from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix"
]
},
{
"cell_type": "markdown",
"id": "bde4590a-e868-4668-8b88-9b7ae6741c02",
"metadata": {},
"source": [
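"# Update class labels\n",
"\n",
"Write each point's predicted class (loaded from `test-on-e2e.json`) into the `label` field of its annotation JSON. The JSON files are overwritten in place."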
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7aaf0964-d720-4d61-bf66-4e21d58d8c9c",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-22T13:10:57.449661Z",
"start_time": "2024-06-22T13:10:57.447666Z"
}
},
"outputs": [],
"source": [
"# Class id -> class name. Ids start at 1; predicted labels are 0-based argmax indices,\n",
"# so they are shifted by + 1 before being looked up here.\n",
"class_dict = dict(enumerate(['Norm', 'SV', 'LineSV'], start=1))\n",
"\n",
"# Inverse mapping (class name -> class id), derived so the two can never drift apart.\n",
"class_dict_rev = {cls_name: cls_id for cls_id, cls_name in class_dict.items()}"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "41873800-6b74-4fa4-8c8b-3cecd7089518",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-22T13:10:57.977015Z",
"start_time": "2024-06-22T13:10:57.450135Z"
}
},
"outputs": [
{
"data": {
"text/plain": "41"
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"PRED_JSON = '/home/gao/mouclear/cc/code/egnn/logs/0/version_4/test-on-e2e.json'\n",
"E2E_JSON_GLOB = '/home/gao/mouclear/cc/data/all_sv_e2e/sv/raw/raw/*.json'\n",
"\n",
"# Prediction file: 'name' identifies each point, 'pred' holds one row per point;\n",
"# argmax over axis 1 gives the 0-based predicted class.\n",
"with open(PRED_JSON) as f:\n",
"    pred_data = json.load(f)\n",
"\n",
"name = np.array(pred_data['name'])\n",
"pred = np.argmax(np.array(pred_data['pred']), axis=1)\n",
"# zip() would silently drop unmatched entries; make a length mismatch loud instead.\n",
"assert len(name) == len(pred), f'{len(name)} names vs {len(pred)} predictions'\n",
"pred_dict = dict(zip(name, pred))\n",
"\n",
"json_lst = glob.glob(E2E_JSON_GLOB)\n",
"assert json_lst, f'no annotation files match {E2E_JSON_GLOB}'\n",
"len(json_lst)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "b002854a-b36e-4423-a016-dd088a344681",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-22T13:12:09.565227Z",
"start_time": "2024-06-22T13:10:57.977878Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 41/41 [01:11<00:00, 1.75s/it]\n"
]
}
],
"source": [
"# Overwrite the 'label' of every annotated point with its predicted class name.\n",
"# NOTE(review): this rewrites the JSON files matched above in place; whatever labels they\n",
"# held before are lost. Confirm that is intended (or write to a copy).\n",
"for json_path in tqdm(json_lst):\n",
"    base_name = json_path.rsplit('/', 1)[-1].split('.')[0]\n",
"    points, edge_index, _, _ = load_data(json_path)\n",
"\n",
"    # Prediction keys look like '<base_name>_<coord0>_<coord1>...'.\n",
"    keys = ['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_))) for point in points]\n",
"    labels = np.array([pred_dict[key] for key in keys])\n",
"\n",
"    with open(json_path) as f:\n",
"        annotation = json.load(f)\n",
"\n",
"    # Assumes load_data returns points in the same order as annotation['shapes'] (TODO confirm).\n",
"    for idx, label in enumerate(labels):\n",
"        annotation['shapes'][idx]['label'] = class_dict[label + 1]\n",
"\n",
"    with open(json_path, 'w') as f:\n",
"        json.dump(annotation, f)"
]
},
{
"cell_type": "markdown",
"id": "ac107f45-f949-4fcd-8950-9235b23d07ba",
"metadata": {},
"source": [
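"# Metrics\n",
"\n",
"For every JSON/PNG pair, compare the rasterised point labels against the ground-truth mask separately for each class id 1-3 (`Norm`, `SV`, `LineSV`) using `get_metrics`. Rows of `res` are file-major, so `res[k::3]` holds the rows for class id `k + 1`."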
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "139d5824-cf33-45d4-8fce-e525245295ea",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-22T13:12:09.567187Z",
"start_time": "2024-06-22T13:12:09.565874Z"
}
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"id": "8ccaf1cb-0202-4f92-9757-ba2eaf6be30a",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-22T13:12:09.571148Z",
"start_time": "2024-06-22T13:12:09.567671Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"0it [00:00, ?it/s]\n"
]
}
],
"source": [
"# Graph JSONs (<name>.json) and ground-truth masks (<name>.png) are expected side by side.\n",
"# NOTE(review): this directory differs from the one updated in the previous section\n",
"# (all_sv_e2e/sv/raw/raw). On the recorded run it matched no files, which left `res` empty\n",
"# and made the slicing cells below fail with IndexError. Confirm which directory is intended.\n",
"METRICS_JSON_GLOB = '/home/gao/mouclear/cc/data/end-to-end-result/gnn_data/test_on_truth/raw/raw/*.json'\n",
"json_lst = glob.glob(METRICS_JSON_GLOB)\n",
"if not json_lst:\n",
"    raise FileNotFoundError(f'no JSON files match {METRICS_JSON_GLOB}')\n",
"\n",
"# One metrics entry per (file, class) pair, file-major: rows k, k + len(class_ids), ...\n",
"# belong to class id class_ids[k].\n",
"class_ids = sorted(class_dict)\n",
"res = []\n",
"\n",
"for json_path in tqdm(json_lst):\n",
"    points, edge_index, labels, _ = load_data(json_path)\n",
"    mask_gt = np.array(Image.open(os.path.splitext(json_path)[0] + '.png'), np.uint8)\n",
"\n",
"    # Rasterise the JSON's point labels (0-based -> 1-based class ids, 0 = no point) on a\n",
"    # canvas sized like the ground truth instead of a hard-coded 2048 x 2048.\n",
"    mask_pd = np.zeros(mask_gt.shape[:2], np.uint8)\n",
"    mask_pd[points[:, 0], points[:, 1]] = labels + 1\n",
"\n",
"    for cls_id in class_ids:\n",
"        res.append(get_metrics(mask_gt == cls_id, mask_pd == cls_id))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cca8debf-9de3-4473-a7eb-bd3ca8884a2e",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-22T13:12:09.580502Z",
"start_time": "2024-06-22T13:12:09.571614Z"
}
},
"outputs": [],
"source": [
"res = np.array(res)\n",
"# Expect one row of metric values per (file, class) pair. A non-2-D result (e.g. `res` was\n",
"# empty because no files were found) would otherwise only surface as an IndexError below.\n",
"assert res.ndim == 2, f'expected a 2-D metrics array, got shape {res.shape}'\n",
"res.shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5d92bb3b-e3c2-4373-8ae1-2bbdc901476b",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-22T13:12:09.730432Z",
"start_time": "2024-06-22T13:12:09.581013Z"
}
},
"outputs": [
{
"ename": "IndexError",
"evalue": "too many indices for array: array is 1-dimensional, but 2 were indexed",
"output_type": "error",
"traceback": [
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[0;31mIndexError\u001B[0m Traceback (most recent call last)",
"Cell \u001B[0;32mIn[7], line 2\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[38;5;66;03m# Norm\u001B[39;00m\n\u001B[0;32m----> 2\u001B[0m \u001B[38;5;28mprint\u001B[39m(np\u001B[38;5;241m.\u001B[39mmean(\u001B[43mres\u001B[49m\u001B[43m[\u001B[49m\u001B[43m:\u001B[49m\u001B[43m:\u001B[49m\u001B[38;5;241;43m3\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43m:\u001B[49m\u001B[43m]\u001B[49m, axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m))\n",
"\u001B[0;31mIndexError\u001B[0m: too many indices for array: array is 1-dimensional, but 2 were indexed"
]
}
],
"source": [
"# Norm (class id 1): rows 0, 3, 6, ... of res\n",
"norm_metrics = res[0::3, :].mean(axis=0)\n",
"print(norm_metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "995b9bab-768c-47e2-9199-91fa3572380b",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-22T13:12:09.731220Z",
"start_time": "2024-06-22T13:12:09.731147Z"
}
},
"outputs": [],
"source": [
"# SV (class id 2): rows 1, 4, 7, ... of res\n",
"sv_metrics = res[1::3, :].mean(axis=0)\n",
"print(sv_metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34a00da7-8576-4102-a3a8-e78e15db0d45",
"metadata": {
"ExecuteTime": {
"end_time": "2024-06-22T13:12:09.731842Z",
"start_time": "2024-06-22T13:12:09.731701Z"
}
},
"outputs": [],
"source": [
"# LineSV (class id 3): rows 2, 5, 8, ... of res\n",
"linesv_metrics = res[2::3, :].mean(axis=0)\n",
"print(linesv_metrics)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4926aea8-b9d2-4878-b8db-5f676d067bc1",
"metadata": {},
"outputs": [],
"source": [
"# All classes: mean over every (file, class) row of res\n",
"overall_metrics = res.mean(axis=0)\n",
"print(overall_metrics)"
]
},
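{
"cell_type": "markdown",
"id": "24043672-21c7-40ef-91a1-5c612cfd6c78",
"metadata": {},
"source": [
"## Recorded results: train+test with vor\n",
"\n",
"Means copied from an earlier run of the cells above (not recomputed here). Rows follow the order of the print cells above. Each row holds the three values returned by `get_metrics` (presumably precision, recall, F1; confirm the order in `utils.e2e_metrics`). *All* is the mean over every row of `res`, which equals the mean of the three class rows.\n",
"\n",
"| Class | value 1 | value 2 | value 3 |\n",
"|---|---|---|---|\n",
"| Norm | 0.994757 | 0.995859 | 0.995307 |\n",
"| SV | 0.930622 | 0.869687 | 0.898882 |\n",
"| LineSV | 0.911436 | 0.933655 | 0.922190 |\n",
"| All | 0.945605 | 0.933067 | 0.938793 |"
]
}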
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}