522 lines
15 KiB
Plaintext
522 lines
15 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "f7589eab-423f-48be-88cc-96348b018bc7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"import cv2\n",
|
|
"import glob\n",
|
|
"import json\n",
|
|
"import numpy as np\n",
|
|
"import pandas as pd\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"from PIL import Image\n",
|
|
"from utils.e2e_metrics import get_metrics\n",
|
|
"from core.data import get_y_3\n",
|
|
"from core.data import load_data\n",
|
|
"from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "ae5eee3b-f14c-474b-85df-a1b9ade172b5",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Data Visualization"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "6f40a974-c629-4ddd-a580-a858fafe0be8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import numpy as np\n",
|
|
"import networkx as nx\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"\n",
|
|
"def graph_showing(data):\n",
|
|
" '''\n",
|
|
" args:\n",
|
|
" data: torch_geometric.data.Data\n",
|
|
" '''\n",
|
|
" G = nx.Graph()\n",
|
|
" edge_index = data['edge_index'].t()\n",
|
|
" edge_index = np.array(edge_index.cpu())\n",
|
|
" \n",
|
|
" G.add_edges_from(edge_index)\n",
|
|
" nx.draw(G)\n",
|
|
" plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d4da7d46-2f19-43b2-ad79-216b8ee124b3",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"source": [
|
|
"# Stage Score"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "b5399705-4d50-4589-ad8b-2e4212465c2f",
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# with open('./logs/TEM/version_0/test.json') as f:\n",
|
|
"# data = json.load(f)\n",
|
|
"\n",
|
|
"# name = np.array(data['name'])\n",
|
|
"# pred = np.array(data['pred'])\n",
|
|
"# label = np.array(data['label'])\n",
|
|
"\n",
|
|
"# pred = np.argmax(pred, axis=1)\n",
|
|
"# label[label != 0] = 1\n",
|
|
"# slides = np.array([int(item.split('_')[0]) for item in name])\n",
|
|
"# idx = [item not in [4, 8] for item in slides]\n",
|
|
"# lb = label[idx]\n",
|
|
"# pd = pred[idx]\n",
|
|
"# accuracy_score(lb, pd), precision_score(lb, pd), recall_score(lb, pd), f1_score(lb, pd), confusion_matrix(lb, pd)\n",
|
|
"\n",
|
|
"# for item in set(slides):\n",
|
|
"# if item in [4, 8]:\n",
|
|
"# continue\n",
|
|
"# idx = slides == item\n",
|
|
"# lb = label[idx]\n",
|
|
"# pd = pred[idx]\n",
|
|
"# print(item, accuracy_score(lb, pd), precision_score(lb, pd), recall_score(lb, pd), f1_score(lb, pd))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "8a45ddee-174c-4ae1-b34b-4940e7e558ae",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# with open('./logs/TEM/version_0/test.json') as f:\n",
|
|
"# data = json.load(f)\n",
|
|
"\n",
|
|
"# name = np.array(data['name'])\n",
|
|
"# pred = np.argmax(np.array(data['pred']), axis=1)\n",
|
|
"# pred_dict = dict(zip(name, pred))\n",
|
|
"\n",
|
|
"# json_lst = glob.glob('../../data/gnn_data/test/raw/*.json', recursive=True)\n",
|
|
"\n",
|
|
"# ress = []\n",
|
|
"# for json_path in json_lst:\n",
|
|
"# base_name = json_path.split('/')[-1].split('.')[0]\n",
|
|
"# if base_name in ['4', '8']:\n",
|
|
"# continue\n",
|
|
" \n",
|
|
"# points, edge_index, labels, lights = load_data(json_path)\n",
|
|
"# mask_gt = np.zeros((2048, 2048))\n",
|
|
"# mask_gt[points[:, 0], points[:, 1]] = labels + 1\n",
|
|
" \n",
|
|
"# labels = np.array([pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))] for point in points])\n",
|
|
"# labels = get_y_3(labels, edge_index)\n",
|
|
" \n",
|
|
"# mask_pd = np.zeros((2048, 2048))\n",
|
|
"# mask_pd[points[:, 0], points[:, 1]] = labels + 1\n",
|
|
" \n",
|
|
"# h, w = np.where(mask_pd != 0)\n",
|
|
"# pd = mask_pd[h, w]\n",
|
|
"# gt = mask_gt[h, w]\n",
|
|
" \n",
|
|
"# res = [precision_score(gt, pd, average='macro'), recall_score(gt, pd, average='macro'), f1_score(gt, pd, average='macro')]\n",
|
|
" \n",
|
|
"# ress += [res]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "a874300a-7d54-4ca1-88cf-06a86004386d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# np.mean(ress, axis=0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "cd11d150-56d8-433b-bb86-16d7d61ae2c9",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Stage Visualization"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "090cfe10-e86a-4f81-9d8f-97872677e837",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# def plot_json(json_path):\n",
|
|
"# colors = ['red', 'yellow', 'blue']\n",
|
|
" \n",
|
|
"# base_name = json_path.split('/')[-1].split('.')[0]\n",
|
|
"# img_path = json_path.replace('.json', '.jpg')\n",
|
|
"# img = np.array(Image.open(img_path))\n",
|
|
"# points, edge_index, labels, lights = load_data(json_path)\n",
|
|
"# mask_gt = np.zeros((2048, 2048))\n",
|
|
"# mask_gt[points[:, 0], points[:, 1]] = labels + 1\n",
|
|
" \n",
|
|
"# labels = np.array([pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))] for point in points])\n",
|
|
"# labels = get_y_3(labels, edge_index)\n",
|
|
"\n",
|
|
"# sp = points[edge_index[0].astype(np.int32)]\n",
|
|
"# tp = points[edge_index[1].astype(np.int32)]\n",
|
|
" \n",
|
|
"# plt.figure(figsize=(24, 9))\n",
|
|
"\n",
|
|
"# plt.subplot(1, 3, 1)\n",
|
|
"# plt.imshow(img, cmap='gray')\n",
|
|
"# plt.axis('off')\n",
|
|
"# plt.title('Image')\n",
|
|
"\n",
|
|
"# plt.subplot(1, 3, 2)\n",
|
|
"# plt.imshow(img, cmap='gray')\n",
|
|
"# for i in [0, 1, 2]:\n",
|
|
"# plt.scatter(points[labels == i][:, 1], points[labels == i][:, 0], c=colors[i], s=30, zorder=2)\n",
|
|
"# plt.axis('off')\n",
|
|
"# plt.title('Points_Pred_'+base_name)\n",
|
|
"\n",
|
|
"# plt.subplot(1, 3, 3)\n",
|
|
"# bg = np.zeros((2048, 2048)) + 255\n",
|
|
"# bg[0, 0] = 0\n",
|
|
"# plt.imshow(bg, cmap='gray')\n",
|
|
"# # plt.imshow(img, cmap='gray')\n",
|
|
"# for i in range(len(sp)):\n",
|
|
"# plt.plot([sp[i][1], tp[i][1]], [sp[i][0], tp[i][0]], linewidth=1, c='green', zorder=1)\n",
|
|
" \n",
|
|
"# plt.scatter(points[:, 1], points[:, 0], s=5, zorder=2)\n",
|
|
"# for i in [0, 1, 2]:\n",
|
|
"# plt.scatter(points[labels == i][:, 1], points[labels == i][:, 0], c=colors[i], s=5, zorder=2)\n",
|
|
"# plt.axis('off')\n",
|
|
"# plt.title('Edge_Pred_'+base_name)\n",
|
|
"# plt.tight_layout()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "5c33b139-b84b-44ce-aed7-e4c8f7613853",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# with open('./logs/TEM/version_0/test.json') as f:\n",
|
|
"# data = json.load(f)\n",
|
|
"\n",
|
|
"# name = np.array(data['name'])\n",
|
|
"# pred = np.argmax(np.array(data['pred']), axis=1)\n",
|
|
"# pred_dict = dict(zip(name, pred))\n",
|
|
"\n",
|
|
"# json_lst = glob.glob('../../data/gnn_data/test/raw/*.json', recursive=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "2b08db8e-57b4-4513-9045-7ae6cdec04fd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# plot_json(json_lst[8])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "5f26f3be-f279-4b3d-8e39-3457424ddb4f",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Stage t-SNE"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "b840655a-015c-4453-9256-afe36e08cce5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# import torch\n",
|
|
"# import numpy as np\n",
|
|
"# import torch.nn as nn\n",
|
|
"# import collections\n",
|
|
"# from core.model import GNN\n",
|
|
"# from core.data import AtomDataset\n",
|
|
"# from sklearn.manifold import TSNE\n",
|
|
"# import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "1ddbfef7-57c4-46fd-a32c-bd82ad468b50",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ckpt = torch.load('./logs/TEM/version_0/checkpoints/epoch=17-val_loss=0.03-val_acc=0.99.ckpt')['state_dict']\n",
|
|
"# ckpt = collections.OrderedDict([(k.replace('model.', ''), v) for k, v in ckpt.items()])\n",
|
|
"# model = GNN()\n",
|
|
"# model.load_state_dict(ckpt, strict=False);\n",
|
|
"# model.fc = nn.Identity()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "75c9970d-a815-49c0-acce-b1de3d59458d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# ds = AtomDataset('../../data/gnn_data/test/')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "1ce37478-d35a-4deb-897d-326b45f7e0bc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# model.eval()\n",
|
|
"# gts = []\n",
|
|
"# features = []\n",
|
|
"# for data in ds:\n",
|
|
"# features += model(data.x, data.edge_index, None).detach().cpu().numpy().tolist()\n",
|
|
"# gts += [int(data.y[0])]\n",
|
|
"\n",
|
|
"# features = np.array(features)\n",
|
|
"# gts = np.array(gts)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "1d6df9cc-6c32-4393-8bef-04ce11145f1d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# tsne = TSNE(n_components=2).fit_transform(features)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "f5d9cd1a-b77c-45b4-b076-d14d635e23c5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# colors = ['green', 'red', 'blue']\n",
|
|
"# c = [colors[item] for item in gts]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "fc71be3d-8da1-4d5d-92c0-ed5cfe03f2e2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# plt.scatter(tsne[:, 0], tsne[:, 1], c=gts)\n",
|
|
"# plt.axis('off')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "15626028-66f3-4157-aca2-81be60c7ebe2",
|
|
"metadata": {},
|
|
"source": [
|
|
"# E2E Score"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "92628e75-bb7d-450b-bf7e-d1493b79e479",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"6 (0.9776635698754694, 0.9546419610113878, 0.9660156249999999)\n",
|
|
"6 (0.024734982332155476, 0.06086956521739131, 0.035175879396984924)\n",
|
|
"6 (0.0, 0.0, 0)\n",
|
|
"2 (0.9947694691979853, 0.9890215716486903, 0.991887193355225)\n",
|
|
"2 (0.008658008658008658, 0.015037593984962405, 0.010989010989010988)\n",
|
|
"2 (0.0070921985815602835, 0.0043859649122807015, 0.0054200542005420045)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"with open('./logs/0/version_0/e2e.json') as f:\n",
|
|
" data = json.load(f)\n",
|
|
"\n",
|
|
"name = np.array(data['name'])\n",
|
|
"pred = np.argmax(np.array(data['pred']), axis=1)\n",
|
|
"pred_dict = dict(zip(name, pred))\n",
|
|
"\n",
|
|
"json_lst = glob.glob('../../data/gnn_data/e2e/raw/*.json', recursive=True)\n",
|
|
"\n",
|
|
"for json_path in json_lst:\n",
|
|
" base_name = json_path.split('/')[-1].split('.')[0]\n",
|
|
" points, edge_index, _, _ = load_data(json_path)\n",
|
|
" labels = np.array([pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))] for point in points])\n",
|
|
" labels = get_y_3(labels, edge_index)\n",
|
|
" \n",
|
|
" mask_pd = np.zeros((2048, 2048))\n",
|
|
" mask_pd[points[:, 0], points[:, 1]] = labels + 1\n",
|
|
" \n",
|
|
" mask_gt = np.array(Image.open(json_path.replace('.json', '.png')))\n",
|
|
" \n",
|
|
" for i in range(1, 4):\n",
|
|
" print(base_name, get_metrics(mask_gt == i, mask_pd == i))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d432c4b0-9843-472b-91b3-ec57b076fa1f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"id": "0dc4a5cd-645e-41b5-9e93-e6a43329c4a3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# def plot3(json_path, _type='gt'):\n",
|
|
"# colors = ['red', 'yellow', 'blue']\n",
|
|
" \n",
|
|
"# base_name = json_path.split('/')[-1].split('.')[0]\n",
|
|
"# img_path = json_path.replace('.json', '.jpg')\n",
|
|
"# img = np.array(Image.open(img_path))\n",
|
|
"\n",
|
|
"# points, edge_index, labels, _ = load_data(json_path)\n",
|
|
"# # labels = np.array([pred_dict['{}_{}'.format(base_name, '_'.join(np.array(point, np.str_)))] for point in points])\n",
|
|
"# # labels = get_y_3(labels, edge_index)\n",
|
|
" \n",
|
|
"# sp = points[edge_index[0].astype(np.int32)]\n",
|
|
"# tp = points[edge_index[1].astype(np.int32)]\n",
|
|
" \n",
|
|
"# plt.figure(figsize=(24, 9))\n",
|
|
"\n",
|
|
"# plt.subplot(1, 3, 1)\n",
|
|
"# plt.imshow(img, cmap='gray')\n",
|
|
"# plt.axis('off')\n",
|
|
"# plt.title('Image')\n",
|
|
"\n",
|
|
"# plt.subplot(1, 3, 2)\n",
|
|
"# plt.imshow(img, cmap='gray')\n",
|
|
"# for i in [0, 1, 2]:\n",
|
|
"# plt.scatter(points[labels == i][:, 1], points[labels == i][:, 0], c=colors[i], s=30, zorder=2)\n",
|
|
"# plt.axis('off')\n",
|
|
"# plt.title('Points_Pred_'+base_name)\n",
|
|
"\n",
|
|
"# plt.subplot(1, 3, 3)\n",
|
|
"# bg = np.zeros((2048, 2048)) + 255\n",
|
|
"# bg[0, 0] = 0\n",
|
|
"# plt.imshow(bg, cmap='gray')\n",
|
|
"# # plt.imshow(img, cmap='gray')\n",
|
|
"# for i in range(len(sp)):\n",
|
|
"# plt.plot([sp[i][1], tp[i][1]], [sp[i][0], tp[i][0]], linewidth=1, c='green', zorder=1)\n",
|
|
" \n",
|
|
"# plt.scatter(points[:, 1], points[:, 0], s=5, zorder=2)\n",
|
|
"# for i in [0, 1, 2]:\n",
|
|
"# plt.scatter(points[labels == i][:, 1], points[labels == i][:, 0], c=colors[i], s=5, zorder=2)\n",
|
|
"# plt.axis('off')\n",
|
|
"# plt.title('Edge_Pred_'+base_name)\n",
|
|
"# plt.tight_layout()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"id": "cfd26d35-fc84-4cbf-bb76-382dcc23c497",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# with open('./logs/TEM/version_0/e2e.json') as f:\n",
|
|
"# data = json.load(f)\n",
|
|
"\n",
|
|
"# name = np.array(data['name'])\n",
|
|
"# pred = np.argmax(np.array(data['pred']), axis=1)\n",
|
|
"# pred_dict = dict(zip(name, pred))\n",
|
|
"\n",
|
|
"# json_lst = glob.glob('../../data/gnn_data/e2e/raw/*.json', recursive=True)\n",
|
|
"# json_path = json_lst[7]\n",
|
|
"# plot3(json_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"id": "3032d9f2-68e6-44c5-be1d-0f31887e8b72",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# json_lst = glob.glob('../../data/gnn_data/test/raw/*.json', recursive=True)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"id": "679a32db-5a80-41ca-acf3-e2cfec5d6d0c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# json_path = json_lst[7]\n",
|
|
"# plot3(json_path)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "07faa19f-7b77-48b3-858f-27eb98f35633",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "cmae",
|
|
"language": "python",
|
|
"name": "cmae"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|