atom-predict/egnn_v2/.ipynb_checkpoints/Metrics-checkpoint.ipynb

522 lines
15 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "f7589eab-423f-48be-88cc-96348b018bc7",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import glob\n",
"import json\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from PIL import Image\n",
"from utils.e2e_metrics import get_metrics\n",
"from core.data import get_y_3\n",
"from core.data import load_data\n",
"from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix"
]
},
{
"cell_type": "markdown",
"id": "ae5eee3b-f14c-474b-85df-a1b9ade172b5",
"metadata": {},
"source": [
"# Data Vis"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6f40a974-c629-4ddd-a580-a858fafe0be8",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import numpy as np\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def graph_showing(data):\n",
"    '''\n",
"    Draw the connectivity stored under data['edge_index'] as an undirected graph.\n",
"\n",
"    args:\n",
"        data: torch_geometric.data.Data (only its 'edge_index' entry is read)\n",
"\n",
"    The graph is built from the edge list alone, so only nodes that take part\n",
"    in at least one edge are drawn. The figure is shown with plt.show();\n",
"    nothing is returned.\n",
"    '''\n",
"    # One row per edge after the transpose -- assumes edge_index is laid out\n",
"    # as (2, num_edges); TODO confirm against the dataset code.\n",
"    edge_rows = np.array(data['edge_index'].t().cpu())\n",
"\n",
"    graph = nx.Graph()\n",
"    graph.add_edges_from(edge_rows)\n",
"    nx.draw(graph)\n",
"    plt.show()"
]
},
{
"cell_type": "markdown",
"id": "d4da7d46-2f19-43b2-ad79-216b8ee124b3",
"metadata": {
"tags": []
},
"source": [
"# Stage Score"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "b5399705-4d50-4589-ad8b-2e4212465c2f",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"# with open('./logs/TEM/version_0/test.json') as f:\n",
"# data = json.load(f)\n",
"\n",
"# name = np.array(data['name'])\n",
"# pred = np.array(data['pred'])\n",
"# label = np.array(data['label'])\n",
"\n",
"# pred = np.argmax(pred, axis=1)\n",
"# label[label != 0] = 1\n",
"# slides = np.array([int(item.split('_')[0]) for item in name])\n",
"# idx = [item not in [4, 8] for item in slides]\n",
"# lb = label[idx]\n",
"# pd = pred[idx]\n",
"# accuracy_score(lb, pd), precision_score(lb, pd), recall_score(lb, pd), f1_score(lb, pd), confusion_matrix(lb, pd)\n",
"\n",
"# for item in set(slides):\n",
"# if item in [4, 8]:\n",
"# continue\n",
"# idx = slides == item\n",
"# lb = label[idx]\n",
"# pd = pred[idx]\n",
"# print(item, accuracy_score(lb, pd), precision_score(lb, pd), recall_score(lb, pd), f1_score(lb, pd))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8a45ddee-174c-4ae1-b34b-4940e7e558ae",
"metadata": {},
"outputs": [],
"source": [
"# with open('./logs/TEM/version_0/test.json') as f:\n",
"# data = json.load(f)\n",
"\n",
"# name = np.array(data['name'])\n",
"# pred = np.argmax(np.array(data['pred']), axis=1)\n",
"# pred_dict = dict(zip(name, pred))\n",
"\n",
"# json_lst = glob.glob('../../data/gnn_data/test/raw/*.json', recursive=True)\n",
"\n",
"# ress = []\n",
"# for json_path in json_lst:\n",
"# base_name = json_path.split('/')[-1].split('.')[0]\n",
"# if base_name in ['4', '8']:\n",
"# continue\n",
" \n",
"# points, edge_index, labels, lights = load_data(json_path)\n",
"# mask_gt = np.zeros((2048, 2048))\n",
"# mask_gt[points[:, 0], points[:, 1]] = labels + 1\n",
" \n",
"# labels = np.array([pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))] for point in points])\n",
"# labels = get_y_3(labels, edge_index)\n",
" \n",
"# mask_pd = np.zeros((2048, 2048))\n",
"# mask_pd[points[:, 0], points[:, 1]] = labels + 1\n",
" \n",
"# h, w = np.where(mask_pd != 0)\n",
"# pd = mask_pd[h, w]\n",
"# gt = mask_gt[h, w]\n",
" \n",
"# res = [precision_score(gt, pd, average='macro'), recall_score(gt, pd, average='macro'), f1_score(gt, pd, average='macro')]\n",
" \n",
"# ress += [res]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "a874300a-7d54-4ca1-88cf-06a86004386d",
"metadata": {},
"outputs": [],
"source": [
"# np.mean(ress, axis=0)"
]
},
{
"cell_type": "markdown",
"id": "cd11d150-56d8-433b-bb86-16d7d61ae2c9",
"metadata": {},
"source": [
"# Stage Vis"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "090cfe10-e86a-4f81-9d8f-97872677e837",
"metadata": {},
"outputs": [],
"source": [
"# def plot_json(json_path):\n",
"# colors = ['red', 'yellow', 'blue']\n",
" \n",
"# base_name = json_path.split('/')[-1].split('.')[0]\n",
"# img_path = json_path.replace('.json', '.jpg')\n",
"# img = np.array(Image.open(img_path))\n",
"# points, edge_index, labels, lights = load_data(json_path)\n",
"# mask_gt = np.zeros((2048, 2048))\n",
"# mask_gt[points[:, 0], points[:, 1]] = labels + 1\n",
" \n",
"# labels = np.array([pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))] for point in points])\n",
"# labels = get_y_3(labels, edge_index)\n",
"\n",
"# sp = points[edge_index[0].astype(np.int32)]\n",
"# tp = points[edge_index[1].astype(np.int32)]\n",
" \n",
"# plt.figure(figsize=(24, 9))\n",
"\n",
"# plt.subplot(1, 3, 1)\n",
"# plt.imshow(img, cmap='gray')\n",
"# plt.axis('off')\n",
"# plt.title('Image')\n",
"\n",
"# plt.subplot(1, 3, 2)\n",
"# plt.imshow(img, cmap='gray')\n",
"# for i in [0, 1, 2]:\n",
"# plt.scatter(points[labels == i][:, 1], points[labels == i][:, 0], c=colors[i], s=30, zorder=2)\n",
"# plt.axis('off')\n",
"# plt.title('Points_Pred_'+base_name)\n",
"\n",
"# plt.subplot(1, 3, 3)\n",
"# bg = np.zeros((2048, 2048)) + 255\n",
"# bg[0, 0] = 0\n",
"# plt.imshow(bg, cmap='gray')\n",
"# # plt.imshow(img, cmap='gray')\n",
"# for i in range(len(sp)):\n",
"# plt.plot([sp[i][1], tp[i][1]], [sp[i][0], tp[i][0]], linewidth=1, c='green', zorder=1)\n",
" \n",
"# plt.scatter(points[:, 1], points[:, 0], s=5, zorder=2)\n",
"# for i in [0, 1, 2]:\n",
"# plt.scatter(points[labels == i][:, 1], points[labels == i][:, 0], c=colors[i], s=5, zorder=2)\n",
"# plt.axis('off')\n",
"# plt.title('Edge_Pred_'+base_name)\n",
"# plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5c33b139-b84b-44ce-aed7-e4c8f7613853",
"metadata": {},
"outputs": [],
"source": [
"# with open('./logs/TEM/version_0/test.json') as f:\n",
"# data = json.load(f)\n",
"\n",
"# name = np.array(data['name'])\n",
"# pred = np.argmax(np.array(data['pred']), axis=1)\n",
"# pred_dict = dict(zip(name, pred))\n",
"\n",
"# json_lst = glob.glob('../../data/gnn_data/test/raw/*.json', recursive=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "2b08db8e-57b4-4513-9045-7ae6cdec04fd",
"metadata": {},
"outputs": [],
"source": [
"# plot_json(json_lst[8])"
]
},
{
"cell_type": "markdown",
"id": "5f26f3be-f279-4b3d-8e39-3457424ddb4f",
"metadata": {},
"source": [
"# Stage TSNE"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "b840655a-015c-4453-9256-afe36e08cce5",
"metadata": {},
"outputs": [],
"source": [
"# import torch\n",
"# import numpy as np\n",
"# import torch.nn as nn\n",
"# import collections\n",
"# from core.model import GNN\n",
"# from core.data import AtomDataset\n",
"# from sklearn.manifold import TSNE\n",
"# import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "1ddbfef7-57c4-46fd-a32c-bd82ad468b50",
"metadata": {},
"outputs": [],
"source": [
"# ckpt = torch.load('./logs/TEM/version_0/checkpoints/epoch=17-val_loss=0.03-val_acc=0.99.ckpt')['state_dict']\n",
"# ckpt = collections.OrderedDict([(k.replace('model.', ''), v) for k, v in ckpt.items()])\n",
"# model = GNN()\n",
"# model.load_state_dict(ckpt, strict=False);\n",
"# model.fc = nn.Identity()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "75c9970d-a815-49c0-acce-b1de3d59458d",
"metadata": {},
"outputs": [],
"source": [
"# ds = AtomDataset('../../data/gnn_data/test/')"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "1ce37478-d35a-4deb-897d-326b45f7e0bc",
"metadata": {},
"outputs": [],
"source": [
"# model.eval()\n",
"# gts = []\n",
"# features = []\n",
"# for data in ds:\n",
"# features += model(data.x, data.edge_index, None).detach().cpu().numpy().tolist()\n",
"# gts += [int(data.y[0])]\n",
"\n",
"# features = np.array(features)\n",
"# gts = np.array(gts)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "1d6df9cc-6c32-4393-8bef-04ce11145f1d",
"metadata": {},
"outputs": [],
"source": [
"# tsne = TSNE(n_components=2).fit_transform(features)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f5d9cd1a-b77c-45b4-b076-d14d635e23c5",
"metadata": {},
"outputs": [],
"source": [
"# colors = ['green', 'red', 'blue']\n",
"# c = [colors[item] for item in gts]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "fc71be3d-8da1-4d5d-92c0-ed5cfe03f2e2",
"metadata": {},
"outputs": [],
"source": [
"# plt.scatter(tsne[:, 0], tsne[:, 1], c=gts)\n",
"# plt.axis('off')"
]
},
{
"cell_type": "markdown",
"id": "15626028-66f3-4157-aca2-81be60c7ebe2",
"metadata": {},
"source": [
"# E2E Score"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "92628e75-bb7d-450b-bf7e-d1493b79e479",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"6 (0.9776635698754694, 0.9546419610113878, 0.9660156249999999)\n",
"6 (0.024734982332155476, 0.06086956521739131, 0.035175879396984924)\n",
"6 (0.0, 0.0, 0)\n",
"2 (0.9947694691979853, 0.9890215716486903, 0.991887193355225)\n",
"2 (0.008658008658008658, 0.015037593984962405, 0.010989010989010988)\n",
"2 (0.0070921985815602835, 0.0043859649122807015, 0.0054200542005420045)\n"
]
}
],
"source": [
"# Load the end-to-end predictions and reduce each score vector to its argmax class.\n",
"with open('./logs/0/version_0/e2e.json') as f:\n",
"    data = json.load(f)\n",
"\n",
"name = np.array(data['name'])\n",
"pred = np.array(data['pred']).argmax(axis=1)\n",
"pred_dict = {key: cls for key, cls in zip(name, pred)}\n",
"\n",
"json_lst = glob.glob('../../data/gnn_data/e2e/raw/*.json', recursive=True)\n",
"\n",
"for json_path in json_lst:\n",
"    # File stem: last path component, cut at its first dot.\n",
"    base_name = json_path.rsplit('/', 1)[-1].split('.', 1)[0]\n",
"    points, edge_index, _, _ = load_data(json_path)\n",
"\n",
"    # Predictions are looked up under '<base_name>_<point values joined by _>'.\n",
"    point_keys = ['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_))) for point in points]\n",
"    labels = get_y_3(np.array([pred_dict[key] for key in point_keys]), edge_index)\n",
"\n",
"    # Paint the predicted classes shifted by +1 so that 0 stays the empty value.\n",
"    mask_pd = np.zeros((2048, 2048))\n",
"    mask_pd[points[:, 0], points[:, 1]] = labels + 1\n",
"\n",
"    # The reference mask is the .png stored next to the .json file.\n",
"    png_path = json_path.replace('.json', '.png')\n",
"    mask_gt = np.array(Image.open(png_path))\n",
"\n",
"    # Score mask values 1, 2 and 3 separately, one printed line per value.\n",
"    for i in (1, 2, 3):\n",
"        print(base_name, get_metrics(mask_gt == i, mask_pd == i))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d432c4b0-9843-472b-91b3-ec57b076fa1f",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 17,
"id": "0dc4a5cd-645e-41b5-9e93-e6a43329c4a3",
"metadata": {},
"outputs": [],
"source": [
"# def plot3(json_path, _type='gt'):\n",
"# colors = ['red', 'yellow', 'blue']\n",
" \n",
"# base_name = json_path.split('/')[-1].split('.')[0]\n",
"# img_path = json_path.replace('.json', '.jpg')\n",
"# img = np.array(Image.open(img_path))\n",
"\n",
"# points, edge_index, labels, _ = load_data(json_path)\n",
"# # labels = np.array([pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))] for point in points])\n",
"# # labels = get_y_3(labels, edge_index)\n",
" \n",
"# sp = points[edge_index[0].astype(np.int32)]\n",
"# tp = points[edge_index[1].astype(np.int32)]\n",
" \n",
"# plt.figure(figsize=(24, 9))\n",
"\n",
"# plt.subplot(1, 3, 1)\n",
"# plt.imshow(img, cmap='gray')\n",
"# plt.axis('off')\n",
"# plt.title('Image')\n",
"\n",
"# plt.subplot(1, 3, 2)\n",
"# plt.imshow(img, cmap='gray')\n",
"# for i in [0, 1, 2]:\n",
"# plt.scatter(points[labels == i][:, 1], points[labels == i][:, 0], c=colors[i], s=30, zorder=2)\n",
"# plt.axis('off')\n",
"# plt.title('Points_Pred_'+base_name)\n",
"\n",
"# plt.subplot(1, 3, 3)\n",
"# bg = np.zeros((2048, 2048)) + 255\n",
"# bg[0, 0] = 0\n",
"# plt.imshow(bg, cmap='gray')\n",
"# # plt.imshow(img, cmap='gray')\n",
"# for i in range(len(sp)):\n",
"# plt.plot([sp[i][1], tp[i][1]], [sp[i][0], tp[i][0]], linewidth=1, c='green', zorder=1)\n",
" \n",
"# plt.scatter(points[:, 1], points[:, 0], s=5, zorder=2)\n",
"# for i in [0, 1, 2]:\n",
"# plt.scatter(points[labels == i][:, 1], points[labels == i][:, 0], c=colors[i], s=5, zorder=2)\n",
"# plt.axis('off')\n",
"# plt.title('Edge_Pred_'+base_name)\n",
"# plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "cfd26d35-fc84-4cbf-bb76-382dcc23c497",
"metadata": {},
"outputs": [],
"source": [
"# with open('./logs/TEM/version_0/e2e.json') as f:\n",
"# data = json.load(f)\n",
"\n",
"# name = np.array(data['name'])\n",
"# pred = np.argmax(np.array(data['pred']), axis=1)\n",
"# pred_dict = dict(zip(name, pred))\n",
"\n",
"# json_lst = glob.glob('../../data/gnn_data/e2e/raw/*.json', recursive=True)\n",
"# json_path = json_lst[7]\n",
"# plot3(json_path)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "3032d9f2-68e6-44c5-be1d-0f31887e8b72",
"metadata": {},
"outputs": [],
"source": [
"# json_lst = glob.glob('../../data/gnn_data/test/raw/*.json', recursive=True)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "679a32db-5a80-41ca-acf3-e2cfec5d6d0c",
"metadata": {},
"outputs": [],
"source": [
"# json_path = json_lst[7]\n",
"# plot3(json_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "07faa19f-7b77-48b3-858f-27eb98f35633",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "cmae",
"language": "python",
"name": "cmae"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}