571 lines
14 KiB
Plaintext
571 lines
14 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "733e1374-8631-422e-8af8-440f3d29758e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch\n",
|
|
"import torch.nn as nn\n",
|
|
"from core.data import AtomDataset\n",
|
|
"from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix\n",
|
|
"\n",
|
|
"import timm\n",
|
|
"import torch.nn.functional as F\n",
|
|
"from torch_geometric.nn import GCNConv, SAGEConv, GATConv, GATv2Conv"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a7a921a1-506e-4435-8512-53a3a9015600",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Model"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "15775ae9-8848-4376-999d-083d4aa27b0b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"class GNN(nn.Module):\n",
|
|
" def __init__(self):\n",
|
|
" super(GNN, self).__init__()\n",
|
|
" self.encoder = timm.create_model('resnet18', pretrained=True)\n",
|
|
" self.encoder.fc = nn.Identity()\n",
|
|
" \n",
|
|
" # self.conv1 = GCNConv(512, 512)\n",
|
|
" # self.conv2 = GCNConv(512, 512)\n",
|
|
" # self.conv3 = GCNConv(512, 512)\n",
|
|
" # self.fc = nn.Linear(1024, 2)\n",
|
|
" self.fc = nn.Linear(512, 2)\n",
|
|
" \n",
|
|
" \n",
|
|
" def forward(self, x, edge_index):\n",
|
|
" x_res = self.encoder(x)\n",
|
|
" \n",
|
|
" # x = F.relu(self.conv1(x_res, edge_index))\n",
|
|
" # x = F.relu(self.conv2(x, edge_index))\n",
|
|
" # x = self.conv3(x, edge_index)\n",
|
|
" # x = torch.concatenate([x_res, x], axis=1)\n",
|
|
" x = self.fc(x_res)\n",
|
|
" \n",
|
|
" return x"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "21762ece-e0a8-479e-bebb-f9082ea63b91",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Train"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "cfca09b4-c649-4f1f-8f05-79c50f59d322",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"train_dataset = AtomDataset(root='../../../bk/v15_Final/data/gnn_data/train/')\n",
|
|
"eval_dataset = AtomDataset(root='../../../bk/v15_Final/data/gnn_data/eval/')\n",
|
|
"test_dataset = AtomDataset(root='../../../bk/v15_Final/data/gnn_data/test/')\n",
|
|
"\n",
|
|
"# train_dataset = AtomDataset(root='../../data/gnn_data/train/')\n",
|
|
"# eval_dataset = AtomDataset(root='../../data/gnn_data/valid/')\n",
|
|
"# test_dataset = AtomDataset(root='../../data/gnn_data/test/')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "002a5e35-a302-40cf-a757-a2ed86541a27",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model = GNN()\n",
|
|
"model = model.cuda()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "bf422764-9a88-4dea-96d2-f0ce9ff82be0",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"weight = torch.FloatTensor([1., 1.]).cuda()\n",
|
|
" \n",
|
|
"criterion = nn.CrossEntropyLoss(weight=weight)\n",
|
|
"optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-4, weight_decay=5e-4)\n",
|
|
"# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "05cc5540-8986-48b2-889e-7c498af3b9b5",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def train():\n",
|
|
" model.train()\n",
|
|
" optimizer.zero_grad()\n",
|
|
" \n",
|
|
" for data in train_dataset:\n",
|
|
" data = data.cuda()\n",
|
|
" out = model(data.x, data.edge_index)\n",
|
|
" y = data.y.clone(); y[y != 0] = 1\n",
|
|
" loss = criterion(out, y)\n",
|
|
" \n",
|
|
" loss.backward()\n",
|
|
" optimizer.step()\n",
|
|
" \n",
|
|
" return loss"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "96dfec96-dfe4-491d-bdfd-d0ac118ab290",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"best_f1 = 0."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "0474d86a-7fc3-4106-a579-308268f795c4",
|
|
"metadata": {
|
|
"tags": []
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def evl(dtype):\n",
|
|
" global best_f1\n",
|
|
" model.eval()\n",
|
|
" \n",
|
|
" if dtype == 'train':\n",
|
|
" dataset = train_dataset\n",
|
|
" else:\n",
|
|
" dataset = eval_dataset\n",
|
|
" \n",
|
|
" outs = []\n",
|
|
" lbls = []\n",
|
|
" for data in dataset:\n",
|
|
" data = data.cuda()\n",
|
|
" outs += [model(data.x, data.edge_index)]\n",
|
|
" y = data.y.clone(); y[y != 0] = 1\n",
|
|
" lbls += [y]\n",
|
|
" \n",
|
|
" outs = torch.concat(outs)\n",
|
|
" lbls = torch.concat(lbls)\n",
|
|
" \n",
|
|
" preds = outs.argmax(1)\n",
|
|
" \n",
|
|
" lbls = lbls.cpu()\n",
|
|
" preds = preds.cpu()\n",
|
|
" \n",
|
|
" # test_acc = (preds == lbls).sum() / len(preds)\n",
|
|
" test_f1 = f1_score(lbls, preds)\n",
|
|
" \n",
|
|
" if test_f1 > best_f1:\n",
|
|
" best_f1 = test_f1\n",
|
|
" \n",
|
|
" torch.save(model.state_dict(), './gnn_best.pth')\n",
|
|
" print(dtype)\n",
|
|
" print(confusion_matrix(lbls.cpu().numpy(), preds.cpu().numpy()))\n",
|
|
" print('Save Dtype: {} F1: {}.'.format(dtype, float(test_f1)))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "96757e64-b6da-40e8-8834-670cbffa119e",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch 0, Loss: 1.1485445499420166\n",
|
|
"\n",
|
|
"Epoch 4, Loss: 0.37879377603530884\n",
|
|
"eval\n",
|
|
"[[ 20 5183]\n",
|
|
" [ 1 394]]\n",
|
|
"Save Dtype: eval F1: 0.1319490957803081.\n",
|
|
"\n",
|
|
"Epoch 8, Loss: 0.0703992247581482\n",
|
|
"eval\n",
|
|
"[[3721 1482]\n",
|
|
" [ 4 391]]\n",
|
|
"Save Dtype: eval F1: 0.3447971781305115.\n",
|
|
"\n",
|
|
"Epoch 12, Loss: 0.02669108472764492\n",
|
|
"eval\n",
|
|
"[[4711 492]\n",
|
|
" [ 18 377]]\n",
|
|
"Save Dtype: eval F1: 0.5965189873417722.\n",
|
|
"\n",
|
|
"Epoch 16, Loss: 0.028077654540538788\n",
|
|
"eval\n",
|
|
"[[5099 104]\n",
|
|
" [ 109 286]]\n",
|
|
"Save Dtype: eval F1: 0.7286624203821656.\n",
|
|
"\n",
|
|
"Epoch 20, Loss: 0.01263432390987873\n",
|
|
"eval\n",
|
|
"[[4987 216]\n",
|
|
" [ 26 369]]\n",
|
|
"Save Dtype: eval F1: 0.7530612244897958.\n",
|
|
"\n",
|
|
"Epoch 24, Loss: 0.00836279895156622\n",
|
|
"eval\n",
|
|
"[[5029 174]\n",
|
|
" [ 21 374]]\n",
|
|
"Save Dtype: eval F1: 0.7932131495227996.\n",
|
|
"\n",
|
|
"Epoch 28, Loss: 0.0064726052805781364\n",
|
|
"eval\n",
|
|
"[[5049 154]\n",
|
|
" [ 20 375]]\n",
|
|
"Save Dtype: eval F1: 0.8116883116883117.\n",
|
|
"\n",
|
|
"Epoch 32, Loss: 0.005404990166425705\n",
|
|
"eval\n",
|
|
"[[5063 140]\n",
|
|
" [ 17 378]]\n",
|
|
"Save Dtype: eval F1: 0.828039430449069.\n",
|
|
"\n",
|
|
"Epoch 36, Loss: 0.004729312378913164\n",
|
|
"eval\n",
|
|
"[[5085 118]\n",
|
|
" [ 16 379]]\n",
|
|
"Save Dtype: eval F1: 0.8497757847533631.\n",
|
|
"\n",
|
|
"Epoch 40, Loss: 0.004251164384186268\n",
|
|
"eval\n",
|
|
"[[5102 101]\n",
|
|
" [ 14 381]]\n",
|
|
"Save Dtype: eval F1: 0.8688711516533636.\n",
|
|
"\n",
|
|
"Epoch 44, Loss: 0.0038708539213985205\n",
|
|
"eval\n",
|
|
"[[5119 84]\n",
|
|
" [ 15 380]]\n",
|
|
"Save Dtype: eval F1: 0.8847497089639116.\n",
|
|
"\n",
|
|
"Epoch 48, Loss: 0.0035590955521911383\n",
|
|
"eval\n",
|
|
"[[5131 72]\n",
|
|
" [ 17 378]]\n",
|
|
"Save Dtype: eval F1: 0.8946745562130176.\n",
|
|
"\n",
|
|
"Epoch 52, Loss: 0.003302567871287465\n",
|
|
"eval\n",
|
|
"[[5145 58]\n",
|
|
" [ 22 373]]\n",
|
|
"Save Dtype: eval F1: 0.9031476997578693.\n",
|
|
"\n",
|
|
"Epoch 56, Loss: 0.003077542642131448\n",
|
|
"eval\n",
|
|
"[[5152 51]\n",
|
|
" [ 24 371]]\n",
|
|
"Save Dtype: eval F1: 0.9082007343941249.\n",
|
|
"\n",
|
|
"Epoch 60, Loss: 0.0028757487889379263\n",
|
|
"eval\n",
|
|
"[[5158 45]\n",
|
|
" [ 26 369]]\n",
|
|
"Save Dtype: eval F1: 0.9122373300370827.\n",
|
|
"\n",
|
|
"Epoch 64, Loss: 0.002683074912056327\n",
|
|
"\n",
|
|
"Epoch 68, Loss: 0.0025064863730221987\n",
|
|
"\n",
|
|
"Epoch 72, Loss: 0.002350787864997983\n",
|
|
"\n",
|
|
"Epoch 76, Loss: 0.0022105583921074867\n",
|
|
"\n",
|
|
"Epoch 80, Loss: 0.0020780975464731455\n",
|
|
"\n",
|
|
"Epoch 84, Loss: 0.001956102205440402\n",
|
|
"eval\n",
|
|
"[[5161 42]\n",
|
|
" [ 28 367]]\n",
|
|
"Save Dtype: eval F1: 0.9129353233830845.\n",
|
|
"\n",
|
|
"Epoch 88, Loss: 0.0018357443623244762\n",
|
|
"eval\n",
|
|
"[[5162 41]\n",
|
|
" [ 28 367]]\n",
|
|
"Save Dtype: eval F1: 0.9140722291407223.\n",
|
|
"\n",
|
|
"Epoch 92, Loss: 0.0017174314707517624\n",
|
|
"\n",
|
|
"Epoch 96, Loss: 0.0015771074686199427\n",
|
|
"\n",
|
|
"Epoch 100, Loss: 0.0013864610809832811\n",
|
|
"\n",
|
|
"Epoch 104, Loss: 0.0012668231502175331\n",
|
|
"\n",
|
|
"Epoch 108, Loss: 0.001196257653646171\n",
|
|
"\n",
|
|
"Epoch 112, Loss: 0.0011406401172280312\n",
|
|
"\n",
|
|
"Epoch 116, Loss: 0.0010921088978648186\n",
|
|
"\n",
|
|
"Epoch 120, Loss: 0.0010481290519237518\n",
|
|
"\n",
|
|
"Epoch 124, Loss: 0.0010083065135404468\n",
|
|
"\n",
|
|
"Epoch 128, Loss: 0.0009702107636258006\n",
|
|
"\n",
|
|
"Epoch 132, Loss: 0.0009348472231067717\n",
|
|
"\n",
|
|
"Epoch 136, Loss: 0.0009008236229419708\n",
|
|
"\n",
|
|
"Epoch 140, Loss: 0.0008667901856824756\n",
|
|
"\n",
|
|
"Epoch 144, Loss: 0.0008340950589627028\n",
|
|
"\n",
|
|
"Epoch 148, Loss: 0.0008010714082047343\n",
|
|
"\n",
|
|
"Epoch 152, Loss: 0.0007691137725487351\n",
|
|
"\n",
|
|
"Epoch 156, Loss: 0.0007390899700112641\n",
|
|
"\n",
|
|
"Epoch 160, Loss: 0.0007107618148438632\n",
|
|
"\n",
|
|
"Epoch 164, Loss: 0.0006838308181613684\n",
|
|
"\n",
|
|
"Epoch 168, Loss: 0.0006599965272471309\n",
|
|
"\n",
|
|
"Epoch 172, Loss: 0.0006382830324582756\n",
|
|
"\n",
|
|
"Epoch 176, Loss: 0.0006170718697831035\n",
|
|
"\n",
|
|
"Epoch 180, Loss: 0.0005973792285658419\n",
|
|
"\n",
|
|
"Epoch 184, Loss: 0.0005795774632133543\n",
|
|
"\n",
|
|
"Epoch 188, Loss: 0.0005632488173432648\n",
|
|
"\n",
|
|
"Epoch 192, Loss: 0.0006092883995734155\n",
|
|
"\n",
|
|
"Epoch 196, Loss: 0.004182583186775446\n",
|
|
"\n",
|
|
"Epoch 200, Loss: 0.0012245809193700552\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for epoch in range(201):\n",
|
|
" loss = train()\n",
|
|
" \n",
|
|
" if epoch % 4 == 0:\n",
|
|
" print('Epoch {}, Loss: {}'.format(epoch, loss))\n",
|
|
" \n",
|
|
" if epoch % 4 == 0:\n",
|
|
" evl('eval')\n",
|
|
" print()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"id": "37749f32-6384-4136-84cc-0c7760bd460d",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Reload the weights stored at './gnn_best.pth' (the path evl() checkpoints to).\n",
|
|
"ckpt = torch.load('./gnn_best.pth')\n",
|
|
"model.load_state_dict(ckpt);"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"id": "89bb734d-3df9-4a14-9450-a03f1a21854f",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"model.eval()\n",
|
|
"dataset = test_dataset\n",
|
|
"\n",
|
|
"outs = []\n",
|
|
"lbls = []\n",
|
|
"for data in dataset:\n",
|
|
" data = data.cuda()\n",
|
|
" outs += [model(data.x, data.edge_index)]\n",
|
|
" y = data.y.clone(); y[y != 0] = 1\n",
|
|
" lbls += [y]\n",
|
|
"\n",
|
|
"outs = torch.concat(outs)\n",
|
|
"lbls = torch.concat(lbls)\n",
|
|
"\n",
|
|
"preds = outs.argmax(1)\n",
|
|
"test_acc = (preds == lbls).sum() / len(preds)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"id": "11f0ea19-87fd-409b-9fc4-4afd1d709383",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"tensor(0.9892, device='cuda:0')"
|
|
]
|
|
},
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"test_acc"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"id": "e0372274-bb77-4585-8de2-caa61c1167f2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"pts = data.points\n",
|
|
"gts = data.y.detach().cpu().numpy()\n",
|
|
"lbls = lbls.detach().cpu().numpy()\n",
|
|
"preds = preds.detach().cpu().numpy()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"id": "27599106-1bae-4982-b7a9-842784a2b420",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.9156061620897522"
|
|
]
|
|
},
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"precision_score(lbls, preds)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"id": "ae725636-1587-4d56-a817-ab9a0b4b8e40",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.9592982456140351"
|
|
]
|
|
},
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"recall_score(lbls, preds)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"id": "e0998c0a-4734-496d-89c6-92e4f3cfcc0c",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"0.9369431117203564"
|
|
]
|
|
},
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"f1_score(lbls, preds)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "dfa790b6-f5e1-47db-8dba-5a18d0951002",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "3a321d50-0f40-43e0-8a0f-13bd72a9e32a",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "e24ff84c-d069-4197-9535-f0cc58c7a750",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "a95978db-5d5b-4334-8b71-95c6d716558b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "targetdif",
|
|
"language": "python",
|
|
"name": "targetdif"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.11"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|