atom-predict/egnn_v2/.ipynb_checkpoints/GNN-checkpoint.ipynb

571 lines
14 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "733e1374-8631-422e-8af8-440f3d29758e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from core.data import AtomDataset\n",
"from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix\n",
"\n",
"import timm\n",
"import torch.nn.functional as F\n",
"from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GATv2Conv"
]
},
{
"cell_type": "markdown",
"id": "a7a921a1-506e-4435-8512-53a3a9015600",
"metadata": {},
"source": [
"# Model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "15775ae9-8848-4376-999d-083d4aa27b0b",
"metadata": {},
"outputs": [],
"source": [
"class GNN(nn.Module):\n",
"    \"\"\"Binary classifier: pretrained ResNet-18 features followed by a linear head.\n",
"\n",
"    The graph-convolution branch is disabled in this version, so ``edge_index``\n",
"    is accepted by ``forward`` but does not influence the result. An earlier\n",
"    variant stacked three ``GCNConv(512, 512)`` layers and concatenated their\n",
"    output with the encoder features before a ``Linear(1024, 2)`` head.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Pretrained backbone with its classification layer replaced by a\n",
"        # pass-through, so the encoder emits raw features.\n",
"        backbone = timm.create_model('resnet18', pretrained=True)\n",
"        backbone.fc = nn.Identity()\n",
"        self.encoder = backbone\n",
"\n",
"        # 512 encoder features in, two class logits out.\n",
"        self.fc = nn.Linear(512, 2)\n",
"\n",
"    def forward(self, x, edge_index):\n",
"        \"\"\"Return two logits per row of ``x``; ``edge_index`` is currently unused.\"\"\"\n",
"        return self.fc(self.encoder(x))"
]
},
{
"cell_type": "markdown",
"id": "21762ece-e0a8-479e-bebb-f9082ea63b91",
"metadata": {},
"source": [
"# Train"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "cfca09b4-c649-4f1f-8f05-79c50f59d322",
"metadata": {},
"outputs": [],
"source": [
"# Train / validation / test splits, read from the v15_Final data directory.\n",
"# (Alternative in-repo layout: '../../data/gnn_data/' with train/, valid/ and test/.)\n",
"train_dataset, eval_dataset, test_dataset = (\n",
"    AtomDataset(root=split_root)\n",
"    for split_root in (\n",
"        '../../../bk/v15_Final/data/gnn_data/train/',\n",
"        '../../../bk/v15_Final/data/gnn_data/eval/',\n",
"        '../../../bk/v15_Final/data/gnn_data/test/',\n",
"    )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "002a5e35-a302-40cf-a757-a2ed86541a27",
"metadata": {},
"outputs": [],
"source": [
"# Instantiate the classifier and move its parameters to the GPU.\n",
"model = GNN().cuda()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "bf422764-9a88-4dea-96d2-f0ce9ff82be0",
"metadata": {},
"outputs": [],
"source": [
"# Per-class loss weights; both classes currently count equally.\n",
"weight = torch.tensor([1.0, 1.0], dtype=torch.float32).cuda()\n",
"criterion = nn.CrossEntropyLoss(weight=weight)\n",
"\n",
"# RMSprop with weight decay (Adam with the same lr / weight_decay was the alternative tried).\n",
"optimizer = torch.optim.RMSprop(\n",
"    params=model.parameters(),\n",
"    lr=1e-4,\n",
"    weight_decay=5e-4,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "05cc5540-8986-48b2-889e-7c498af3b9b5",
"metadata": {},
"outputs": [],
"source": [
"def train():\n",
"    \"\"\"Run one optimisation pass over ``train_dataset``.\n",
"\n",
"    Every graph yielded by the dataset is treated as its own mini-batch:\n",
"    gradients are cleared, the loss is back-propagated and the optimizer\n",
"    takes one step. Labels are binarised so that every non-zero class\n",
"    becomes class 1.\n",
"\n",
"    Returns:\n",
"        The detached loss tensor of the last graph processed, or ``None``\n",
"        when the dataset yields nothing.\n",
"    \"\"\"\n",
"    model.train()\n",
"\n",
"    loss = None\n",
"    for data in train_dataset:\n",
"        data = data.cuda()\n",
"\n",
"        # Clear gradients for every step so each update is driven only by\n",
"        # the current graph, not by leftovers from earlier graphs.\n",
"        optimizer.zero_grad()\n",
"\n",
"        out = model(data.x, data.edge_index)\n",
"        y = data.y.clone()\n",
"        y[y != 0] = 1\n",
"        loss = criterion(out, y)\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"    # Detach so the caller can log the value without holding the autograd graph.\n",
"    return None if loss is None else loss.detach()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "96dfec96-dfe4-491d-bdfd-d0ac118ab290",
"metadata": {},
"outputs": [],
"source": [
"# Best F1 seen so far; evl() raises it whenever it saves a new checkpoint.\n",
"best_f1 = 0.0"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0474d86a-7fc3-4106-a579-308268f795c4",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"def evl(dtype):\n",
"    \"\"\"Score the model on one split and checkpoint it when F1 improves.\n",
"\n",
"    Args:\n",
"        dtype: ``'train'`` evaluates ``train_dataset``; any other value\n",
"            evaluates ``eval_dataset``.\n",
"\n",
"    Side effects:\n",
"        When the binary F1 score exceeds the module-level ``best_f1``, that\n",
"        global is updated, the weights are written to ``./gnn_best.pth`` and\n",
"        the split name, confusion matrix and new score are printed. Nothing\n",
"        is printed or saved otherwise.\n",
"    \"\"\"\n",
"    global best_f1\n",
"    model.eval()\n",
"\n",
"    dataset = train_dataset if dtype == 'train' else eval_dataset\n",
"\n",
"    preds = []\n",
"    lbls = []\n",
"    # Inference only: without no_grad every collected output would keep its\n",
"    # autograd graph (and the activations behind it) alive on the GPU.\n",
"    with torch.no_grad():\n",
"        for data in dataset:\n",
"            data = data.cuda()\n",
"            preds.append(model(data.x, data.edge_index).argmax(1).cpu())\n",
"            y = data.y.clone()\n",
"            y[y != 0] = 1  # collapse every non-zero class into class 1\n",
"            lbls.append(y.cpu())\n",
"\n",
"    preds = torch.concat(preds).numpy()\n",
"    lbls = torch.concat(lbls).numpy()\n",
"\n",
"    test_f1 = f1_score(lbls, preds)\n",
"\n",
"    if test_f1 > best_f1:\n",
"        best_f1 = test_f1\n",
"\n",
"        torch.save(model.state_dict(), './gnn_best.pth')\n",
"        print(dtype)\n",
"        print(confusion_matrix(lbls, preds))\n",
"        print('Save Dtype: {} F1: {}.'.format(dtype, float(test_f1)))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "96757e64-b6da-40e8-8834-670cbffa119e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 0, Loss: 1.1485445499420166\n",
"\n",
"Epoch 4, Loss: 0.37879377603530884\n",
"eval\n",
"[[ 20 5183]\n",
" [ 1 394]]\n",
"Save Dtype: eval F1: 0.1319490957803081.\n",
"\n",
"Epoch 8, Loss: 0.0703992247581482\n",
"eval\n",
"[[3721 1482]\n",
" [ 4 391]]\n",
"Save Dtype: eval F1: 0.3447971781305115.\n",
"\n",
"Epoch 12, Loss: 0.02669108472764492\n",
"eval\n",
"[[4711 492]\n",
" [ 18 377]]\n",
"Save Dtype: eval F1: 0.5965189873417722.\n",
"\n",
"Epoch 16, Loss: 0.028077654540538788\n",
"eval\n",
"[[5099 104]\n",
" [ 109 286]]\n",
"Save Dtype: eval F1: 0.7286624203821656.\n",
"\n",
"Epoch 20, Loss: 0.01263432390987873\n",
"eval\n",
"[[4987 216]\n",
" [ 26 369]]\n",
"Save Dtype: eval F1: 0.7530612244897958.\n",
"\n",
"Epoch 24, Loss: 0.00836279895156622\n",
"eval\n",
"[[5029 174]\n",
" [ 21 374]]\n",
"Save Dtype: eval F1: 0.7932131495227996.\n",
"\n",
"Epoch 28, Loss: 0.0064726052805781364\n",
"eval\n",
"[[5049 154]\n",
" [ 20 375]]\n",
"Save Dtype: eval F1: 0.8116883116883117.\n",
"\n",
"Epoch 32, Loss: 0.005404990166425705\n",
"eval\n",
"[[5063 140]\n",
" [ 17 378]]\n",
"Save Dtype: eval F1: 0.828039430449069.\n",
"\n",
"Epoch 36, Loss: 0.004729312378913164\n",
"eval\n",
"[[5085 118]\n",
" [ 16 379]]\n",
"Save Dtype: eval F1: 0.8497757847533631.\n",
"\n",
"Epoch 40, Loss: 0.004251164384186268\n",
"eval\n",
"[[5102 101]\n",
" [ 14 381]]\n",
"Save Dtype: eval F1: 0.8688711516533636.\n",
"\n",
"Epoch 44, Loss: 0.0038708539213985205\n",
"eval\n",
"[[5119 84]\n",
" [ 15 380]]\n",
"Save Dtype: eval F1: 0.8847497089639116.\n",
"\n",
"Epoch 48, Loss: 0.0035590955521911383\n",
"eval\n",
"[[5131 72]\n",
" [ 17 378]]\n",
"Save Dtype: eval F1: 0.8946745562130176.\n",
"\n",
"Epoch 52, Loss: 0.003302567871287465\n",
"eval\n",
"[[5145 58]\n",
" [ 22 373]]\n",
"Save Dtype: eval F1: 0.9031476997578693.\n",
"\n",
"Epoch 56, Loss: 0.003077542642131448\n",
"eval\n",
"[[5152 51]\n",
" [ 24 371]]\n",
"Save Dtype: eval F1: 0.9082007343941249.\n",
"\n",
"Epoch 60, Loss: 0.0028757487889379263\n",
"eval\n",
"[[5158 45]\n",
" [ 26 369]]\n",
"Save Dtype: eval F1: 0.9122373300370827.\n",
"\n",
"Epoch 64, Loss: 0.002683074912056327\n",
"\n",
"Epoch 68, Loss: 0.0025064863730221987\n",
"\n",
"Epoch 72, Loss: 0.002350787864997983\n",
"\n",
"Epoch 76, Loss: 0.0022105583921074867\n",
"\n",
"Epoch 80, Loss: 0.0020780975464731455\n",
"\n",
"Epoch 84, Loss: 0.001956102205440402\n",
"eval\n",
"[[5161 42]\n",
" [ 28 367]]\n",
"Save Dtype: eval F1: 0.9129353233830845.\n",
"\n",
"Epoch 88, Loss: 0.0018357443623244762\n",
"eval\n",
"[[5162 41]\n",
" [ 28 367]]\n",
"Save Dtype: eval F1: 0.9140722291407223.\n",
"\n",
"Epoch 92, Loss: 0.0017174314707517624\n",
"\n",
"Epoch 96, Loss: 0.0015771074686199427\n",
"\n",
"Epoch 100, Loss: 0.0013864610809832811\n",
"\n",
"Epoch 104, Loss: 0.0012668231502175331\n",
"\n",
"Epoch 108, Loss: 0.001196257653646171\n",
"\n",
"Epoch 112, Loss: 0.0011406401172280312\n",
"\n",
"Epoch 116, Loss: 0.0010921088978648186\n",
"\n",
"Epoch 120, Loss: 0.0010481290519237518\n",
"\n",
"Epoch 124, Loss: 0.0010083065135404468\n",
"\n",
"Epoch 128, Loss: 0.0009702107636258006\n",
"\n",
"Epoch 132, Loss: 0.0009348472231067717\n",
"\n",
"Epoch 136, Loss: 0.0009008236229419708\n",
"\n",
"Epoch 140, Loss: 0.0008667901856824756\n",
"\n",
"Epoch 144, Loss: 0.0008340950589627028\n",
"\n",
"Epoch 148, Loss: 0.0008010714082047343\n",
"\n",
"Epoch 152, Loss: 0.0007691137725487351\n",
"\n",
"Epoch 156, Loss: 0.0007390899700112641\n",
"\n",
"Epoch 160, Loss: 0.0007107618148438632\n",
"\n",
"Epoch 164, Loss: 0.0006838308181613684\n",
"\n",
"Epoch 168, Loss: 0.0006599965272471309\n",
"\n",
"Epoch 172, Loss: 0.0006382830324582756\n",
"\n",
"Epoch 176, Loss: 0.0006170718697831035\n",
"\n",
"Epoch 180, Loss: 0.0005973792285658419\n",
"\n",
"Epoch 184, Loss: 0.0005795774632133543\n",
"\n",
"Epoch 188, Loss: 0.0005632488173432648\n",
"\n",
"Epoch 192, Loss: 0.0006092883995734155\n",
"\n",
"Epoch 196, Loss: 0.004182583186775446\n",
"\n",
"Epoch 200, Loss: 0.0012245809193700552\n",
"\n"
]
}
],
"source": [
"# 201 epochs; on every 4th one report the loss and score the validation split.\n",
"for epoch in range(201):\n",
"    loss = train()\n",
"\n",
"    if epoch % 4 != 0:\n",
"        continue\n",
"\n",
"    print('Epoch {}, Loss: {}'.format(epoch, loss))\n",
"    evl('eval')\n",
"    print()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "37749f32-6384-4136-84cc-0c7760bd460d",
"metadata": {},
"outputs": [],
"source": [
"ckpt = torch.load('./gnn_best.pth')\n",
"model.load_state_dict(ckpt);"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "89bb734d-3df9-4a14-9450-a03f1a21854f",
"metadata": {},
"outputs": [],
"source": [
"# Score the current model on the test split.\n",
"model.eval()\n",
"dataset = test_dataset\n",
"\n",
"outs = []\n",
"lbls = []\n",
"# Inference only: without no_grad every stored output would keep its autograd graph alive.\n",
"# `data` deliberately stays a module-level name; a later cell reads data.points / data.y.\n",
"with torch.no_grad():\n",
"    for data in dataset:\n",
"        data = data.cuda()\n",
"        outs.append(model(data.x, data.edge_index))\n",
"        y = data.y.clone()\n",
"        y[y != 0] = 1  # collapse every non-zero class into class 1\n",
"        lbls.append(y)\n",
"\n",
"outs = torch.concat(outs)\n",
"lbls = torch.concat(lbls)\n",
"\n",
"preds = outs.argmax(1)\n",
"# Fraction of matching predictions; remains a 0-d tensor on the GPU.\n",
"test_acc = (preds == lbls).sum() / len(preds)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "11f0ea19-87fd-409b-9fc4-4afd1d709383",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.9892, device='cuda:0')"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_acc"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "e0372274-bb77-4585-8de2-caa61c1167f2",
"metadata": {},
"outputs": [],
"source": [
"# NOTE: `data` is the variable left over from the test loop, so `pts` and `gts`\n",
"# describe only the last test graph, while `lbls` / `preds` cover the whole split.\n",
"pts = data.points\n",
"gts = data.y.detach().cpu().numpy()\n",
"\n",
"# torch.as_tensor passes tensors through and wraps arrays, so re-running this cell\n",
"# after `lbls` / `preds` have already become NumPy arrays no longer raises.\n",
"lbls = torch.as_tensor(lbls).detach().cpu().numpy()\n",
"preds = torch.as_tensor(preds).detach().cpu().numpy()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "27599106-1bae-4982-b7a9-842784a2b420",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9156061620897522"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Precision for class 1 (non-zero labels) over the whole test split.\n",
"precision_score(y_true=lbls, y_pred=preds)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "ae725636-1587-4d56-a817-ab9a0b4b8e40",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9592982456140351"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Recall for class 1 (non-zero labels) over the whole test split.\n",
"recall_score(y_true=lbls, y_pred=preds)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e0998c0a-4734-496d-89c6-92e4f3cfcc0c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0.9369431117203564"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# F1 for class 1 (non-zero labels) over the whole test split.\n",
"f1_score(y_true=lbls, y_pred=preds)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dfa790b6-f5e1-47db-8dba-5a18d0951002",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a321d50-0f40-43e0-8a0f-13bd72a9e32a",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "e24ff84c-d069-4197-9535-f0cc58c7a750",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "a95978db-5d5b-4334-8b71-95c6d716558b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "targetdif",
"language": "python",
"name": "targetdif"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}