atom-predict/egnn_v2/.ipynb_checkpoints/MetricsTest-checkpoint.ipynb

303 lines
7.6 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1b9d61dc-a2c3-4400-aa3e-ead3f11a30e5",
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"import json\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"from sklearn.manifold import TSNE\n",
"from sklearn.metrics import f1_score, confusion_matrix\n",
"\n",
"from core.model import PL_EGNN\n",
"from core.data import AtomDataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "28e77574-9a3b-4bfc-8cfa-0f18391c2153",
"metadata": {},
"outputs": [],
"source": [
"# Test split and a freshly constructed model; trained weights are loaded from the checkpoint below.\n",
"ds = AtomDataset('../../data/gnn_data/test/')\n",
"model = PL_EGNN()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "61e67118-eeef-46cb-9ed5-daf70d6a526b",
"metadata": {},
"outputs": [],
"source": [
"ckpt_path = './logs/0/version_0/checkpoints/epoch=15-val_loss=0.03-val_acc=0.99.ckpt'\n",
"# Inference below stays on the CPU (outputs go through .numpy()), so map every tensor to the\n",
"# CPU instead of requiring the device the checkpoint was saved from.\n",
"ckpt = torch.load(ckpt_path, map_location='cpu')['state_dict']\n",
"# Rewrite 'model.model' to 'model' in every key so the names line up with PL_EGNN's parameters\n",
"# (presumably the checkpoint was written by a wrapper around PL_EGNN -- TODO confirm).\n",
"ckpt = collections.OrderedDict((k.replace('model.model', 'model'), v) for k, v in ckpt.items())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9b2b1dc0-edf8-4b23-b594-f2899f623621",
"metadata": {},
"outputs": [],
"source": [
"# strict=False tolerates key mismatches; display them instead of suppressing the result with ';',\n",
"# otherwise a broken key rename would silently leave layers at their initial weights.\n",
"load_result = model.load_state_dict(ckpt, strict=False)\n",
"load_result"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "46cfc31e-9622-4114-a613-a310193b90ed",
"metadata": {},
"outputs": [],
"source": [
"model.eval()\n",
"\n",
"hs, ys, names = [], [], []\n",
"\n",
"# Pure inference: switch autograd off once for the whole pass over the dataset.\n",
"with torch.no_grad():\n",
"    for sample in ds:\n",
"        out = model(sample.light, sample.pos, sample.edge_index)\n",
"        # Keep only the first label entry and the first output row of each sample\n",
"        # (presumably the target atom of the graph -- TODO confirm).\n",
"        ys.append(int(sample.label[0]))\n",
"        hs.append(out[0].numpy().tolist())\n",
"        names.append(sample.name)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "46396926-df6b-4933-95f3-0adbd1245ce6",
"metadata": {},
"outputs": [],
"source": [
"# Stack the per-sample lists into arrays so they can be boolean-masked below.\n",
"ys, hs, names = np.array(ys), np.array(hs), np.array(names)"
]
},
},
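{
"cell_type": "markdown",
"id": "2b636ddf-ca4b-4d41-bd5b-47532efd7dc0",
"metadata": {},
"source": [
"# Metrics: test image 10\n",
"\n",
"Macro F1 and confusion matrix for the samples whose name prefix (the part before the first `_`) is `10`."
]
},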
{
"cell_type": "code",
"execution_count": 7,
"id": "1b2531f1-99fd-4933-b031-4713214be4be",
"metadata": {},
"outputs": [],
"source": [
"# Boolean mask selecting the samples of image '10' (name prefix before the first '_').\n",
"idx = np.array([name.split('_')[0] == '10' for name in names], dtype=bool)\n",
"assert idx.any(), 'no test samples found for image 10'"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e4a5dce8-1cab-4668-8f24-e443eaffdc17",
"metadata": {},
"outputs": [],
"source": [
"# Predicted class = column index of the largest model output for each sample.\n",
"pred = np.argmax(hs, axis=1)\n",
"\n",
"y_gt = ys[idx]\n",
"y_pd = pred[idx]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c594622a-02cd-4403-9c41-1b249023baa6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9756108662005913"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f1_score(y_gt, y_pd, average='macro')"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "eb1c95c4-5989-4639-8909-2f8fcb2e2620",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[5107, 2, 9],\n",
" [ 1, 102, 6],\n",
" [ 2, 1, 350]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Pin the label set to every class index `pred` can take (one per output column), so rows and\n",
"# columns are always indexed by class id, even if a class is absent from this image.\n",
"confusion_matrix(y_gt, y_pd, labels=np.arange(hs.shape[1]))"
]
},
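{
"cell_type": "markdown",
"id": "d81dc85b-214a-4466-bec3-672f36803dce",
"metadata": {},
"source": [
"# Predictions on image 10\n",
"\n",
"Annotated points from `10.json` overlaid on `10.jpg`, coloured by predicted class."
]
},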
{
"cell_type": "code",
"execution_count": null,
"id": "ce981ea8-9714-4119-9afb-8b828b2a90f2",
"metadata": {},
"outputs": [],
"source": [
"RAW_DIR = '../../data/gnn_data/test/raw'\n",
"IMAGE_ID = '10'  # must match the image selected by `idx` in the metrics section\n",
"SAVE_PATH = None  # e.g. './egnn_10.jpg' to also write the figure to disk\n",
"\n",
"# Colour per predicted class id.\n",
"c_dict = {\n",
"    0: '#9BB6CF',\n",
"    1: '#76F1A2',\n",
"    2: '#EDC08C'\n",
"}\n",
"\n",
"with open(f'{RAW_DIR}/{IMAGE_ID}.json') as f:\n",
"    annotation = json.load(f)\n",
"\n",
"# First point of every annotated shape. Colours come from `y_pd`, so the i-th shape must be the\n",
"# i-th sample of this image in the dataset (assumed to share the same ordering -- TODO confirm).\n",
"pts = np.array([shape['points'][0] for shape in annotation['shapes']], np.int32)\n",
"assert len(pts) == len(y_pd), f'{len(pts)} annotated points vs {len(y_pd)} predictions'\n",
"cs = [c_dict[label] for label in y_pd]\n",
"\n",
"img = np.array(Image.open(f'{RAW_DIR}/{IMAGE_ID}.jpg'))\n",
"\n",
"fig, ax = plt.subplots(figsize=(9, 9))\n",
"ax.imshow(img, cmap='gray')\n",
"ax.scatter(pts[:, 0], pts[:, 1], c=cs, s=24)\n",
"ax.axis('off')\n",
"fig.tight_layout()\n",
"if SAVE_PATH is not None:\n",
"    fig.savefig(SAVE_PATH, bbox_inches='tight', dpi=300)"
]
},
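{
"cell_type": "markdown",
"id": "3b8f5872-3e30-4f6a-b73f-accc3bbc016b",
"metadata": {
"tags": []
},
"source": [
"# t-SNE of model outputs\n",
"\n",
"2-D t-SNE projection of the per-sample model outputs over the whole test set, coloured by ground-truth label."
]
},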
{
"cell_type": "code",
"execution_count": null,
"id": "83b4fc09-b7ef-4262-af53-48e743449b50",
"metadata": {},
"outputs": [],
"source": [
"TSNE_EXPORT_PATH = None  # e.g. './test.txt' to dump the embedding as CSV (x1, x2, label)\n",
"\n",
"# Embed the model outputs of the whole test set (not only image 10); fixed seed so reruns match.\n",
"tsne = TSNE(n_components=2, random_state=0)\n",
"X_tsne = tsne.fit_transform(hs)\n",
"\n",
"fig, ax = plt.subplots(figsize=(8, 8))\n",
"points = ax.scatter(X_tsne[:, 0], X_tsne[:, 1], c=ys)\n",
"ax.legend(*points.legend_elements(), title='label')\n",
"ax.set_title('t-SNE of model outputs, coloured by ground-truth label')\n",
"ax.axis('off')\n",
"\n",
"if TSNE_EXPORT_PATH is not None:\n",
"    np.savetxt(TSNE_EXPORT_PATH, np.column_stack([X_tsne, ys]), delimiter=',',\n",
"               header='x1,x2,label', comments='', fmt=['%.6f', '%.6f', '%d'])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cmae",
"language": "python",
"name": "cmae"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}