303 lines
7.6 KiB
Plaintext
303 lines
7.6 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "1b9d61dc-a2c3-4400-aa3e-ead3f11a30e5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"import collections\n",
"import json\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"from PIL import Image\n",
"from sklearn.manifold import TSNE\n",
"from sklearn.metrics import f1_score, confusion_matrix\n",
"\n",
"from core.data import AtomDataset\n",
"from core.model import PL_EGNN"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "28e77574-9a3b-4bfc-8cfa-0f18391c2153",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Dataset read from the test directory, plus a freshly constructed PL_EGNN\n",
"# (its weights are restored from a checkpoint in the following cells).\n",
"ds = AtomDataset(\n",
"    '../../data/gnn_data/test/'\n",
")\n",
"model = PL_EGNN()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "61e67118-eeef-46cb-9ed5-daf70d6a526b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"ckpt_path = './logs/0/version_0/checkpoints/epoch=15-val_loss=0.03-val_acc=0.99.ckpt'\n",
"\n",
"# map_location='cpu' lets a checkpoint written from a GPU run load on a CPU-only machine;\n",
"# the inference loop below works on CPU tensors anyway (outputs go straight to .numpy()).\n",
"# NOTE(review): torch.load unpickles the file, so only open checkpoints you trust.\n",
"ckpt = torch.load(ckpt_path, map_location='cpu')['state_dict']\n",
"\n",
"# Collapse the 'model.model' segment in every key to 'model' so the names can line up with\n",
"# the bare PL_EGNN instance (presumably a training-wrapper prefix; confirm against the trainer).\n",
"ckpt = collections.OrderedDict(\n",
"    (k.replace('model.model', 'model'), v) for k, v in ckpt.items()\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "9b2b1dc0-edf8-4b23-b594-f2899f623621",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# strict=False tolerates key mismatches, so show the report instead of discarding it:\n",
"# if the renamed keys did not match, the affected weights would silently keep their initial values.\n",
"load_report = model.load_state_dict(ckpt, strict=False)\n",
"load_report"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "46cfc31e-9622-4114-a613-a310193b90ed",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Put the model in evaluation mode before collecting outputs.\n",
"model.eval()\n",
"\n",
"# Per-sample accumulators: model output rows, integer labels, sample names.\n",
"hs, ys, names = [], [], []\n",
"\n",
"for data in ds:\n",
"    with torch.no_grad():\n",
"        # Only the first element of the model's output is kept for each sample.\n",
"        out = model(data.light, data.pos, data.edge_index)[0]\n",
"\n",
"        hs.append(out.numpy().tolist())\n",
"        ys.append(int(data.label[0]))\n",
"        names.append(data.name)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "46396926-df6b-4933-95f3-0adbd1245ce6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Turn the three accumulated lists into NumPy arrays so they can be mask-indexed below.\n",
"ys, hs, names = map(np.array, (ys, hs, names))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2b636ddf-ca4b-4d41-bd5b-47532efd7dc0",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Metrics - 10"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "1b2531f1-99fd-4933-b031-4713214be4be",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Boolean mask: True where the text before the first '_' in the sample name is '10'.\n",
"idx = [name.partition('_')[0] == '10' for name in names]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "e4a5dce8-1cab-4668-8f24-e443eaffdc17",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Predicted class = position of the largest model output in each row.\n",
"pred = hs.argmax(axis=1)\n",
"\n",
"# Keep only the samples picked out by the `idx` mask.\n",
"y_pd = pred[idx]\n",
"y_gt = ys[idx]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "c594622a-02cd-4403-9c41-1b249023baa6",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.9756108662005913"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
"# Macro-averaged F1 over the masked samples.\n",
"f1_score(y_true=y_gt, y_pred=y_pd, average='macro')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "eb1c95c4-5989-4639-8909-2f8fcb2e2620",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[5107, 2, 9],\n",
|
|
" [ 1, 102, 6],\n",
|
|
" [ 2, 1, 350]])"
|
|
]
|
|
},
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
"# Rows are the labels in y_gt, columns the predictions in y_pd.\n",
"confusion_matrix(y_true=y_gt, y_pred=y_pd)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c379b3ea-b589-4928-92c5-bce50a620cd0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "d81dc85b-214a-4466-bec3-672f36803dce",
|
|
"metadata": {},
|
|
"source": [
|
|
"# PLOT"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "ce981ea8-9714-4119-9afb-8b828b2a90f2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# Point colour per predicted class id.\n",
"c_dict = {\n",
"    0: '#9BB6CF',\n",
"    1: '#76F1A2',\n",
"    2: '#EDC08C'\n",
"}\n",
"\n",
"# The annotation JSON and the image for sample '10' share one path stem.\n",
"raw_stem = '../../data/gnn_data/test/raw/10'\n",
"\n",
"with open(raw_stem + '.json') as f:\n",
"    annotation = json.load(f)\n",
"\n",
"# First point of every annotated shape, cast to int32 pixel coordinates.\n",
"pts = np.array([item['points'][0] for item in annotation['shapes']], np.int32)\n",
"cs = [c_dict[item] for item in y_pd]\n",
"\n",
"# Colours follow y_pd (dataset order) while points follow the JSON (annotation order).\n",
"# NOTE(review): this assumes both orders agree; confirm against AtomDataset. Only the count is checked here.\n",
"assert len(cs) == len(pts), f'{len(cs)} predictions vs {len(pts)} annotated points'\n",
"\n",
"img = np.array(Image.open(raw_stem + '.jpg'))\n",
"\n",
"fig, ax = plt.subplots(figsize=(9, 9))\n",
"ax.imshow(img, cmap='gray')\n",
"ax.scatter(pts[:, 0], pts[:, 1], c=cs, s=24)\n",
"ax.axis('off')\n",
"fig.tight_layout()\n",
"# Uncomment to write the overlay to disk:\n",
"# fig.savefig('./egnn_10.jpg', bbox_inches='tight', dpi=300)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1b134a83-6fe4-4655-bd68-3e5e8af59636",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "3b8f5872-3e30-4f6a-b73f-accc3bbc016b",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"source": [
|
|
"# TSNE"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "83b4fc09-b7ef-4262-af53-48e743449b50",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"ename": "TypeError",
|
|
"evalue": "wrapped() missing 1 required positional argument: 'X'",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
|
|
"Cell \u001b[0;32mIn[11], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m tsne \u001b[38;5;241m=\u001b[39m TSNE(n_components\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m X_tsne \u001b[38;5;241m=\u001b[39m \u001b[43mtsne\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_transform\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m plt\u001b[38;5;241m.\u001b[39mscatter(X_tsne[:,\u001b[38;5;241m0\u001b[39m], X_tsne[:,\u001b[38;5;241m1\u001b[39m], c\u001b[38;5;241m=\u001b[39mys)\n\u001b[1;32m 5\u001b[0m plt\u001b[38;5;241m.\u001b[39maxis(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124moff\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
|
|
"\u001b[0;31mTypeError\u001b[0m: wrapped() missing 1 required positional argument: 'X'"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"tsne = TSNE(n_components=2)\n",
|
|
"X_tsne = tsne.fit_transform()\n",
|
|
"\n",
|
|
"plt.scatter(X_tsne[:,0], X_tsne[:,1], c=ys)\n",
|
|
"plt.axis('off')\n",
|
|
"\n",
|
|
"# import pandas as pd\n",
|
|
"\n",
|
|
"# df = pd.DataFrame()\n",
|
|
"# df['x1'] = X_tsne[:, 0]\n",
|
|
"# df['x2'] = X_tsne[:, 1]\n",
|
|
"# df['label'] = ys\n",
|
|
"# df.to_csv('./test.txt', index=None)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "041b1eba-cc0d-401e-94a3-cdcf73c1a952",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "cmae",
|
|
"language": "python",
|
|
"name": "cmae"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|