212 lines
5.0 KiB
Plaintext
212 lines
5.0 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"id": "a462eb92-a60a-45c9-a7c9-9256109f9963",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import numpy as np\n",
|
|
"from core.data import AtomDataset\n",
|
|
"from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix\n",
|
|
"from sklearn.ensemble import RandomForestClassifier"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "a7f82353-0ef0-4fc9-91c3-3680625698dc",
|
|
"metadata": {},
|
|
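"source": [
"# GS\n",
"\n",
"Threshold baseline: each graph is scored by the first feature of its first node (`x[0][0]`) and predicted positive (label != 0) when that score is below `thres`. Metrics are reported on the test set with slides 4 and 8 excluded, then as a confusion matrix on the full test set."
]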
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"id": "08eef057-7728-43a5-a415-063c301fdab6",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"DATA_PATH = '../../data/gnn_data/'\n",
"\n",
"# One AtomDataset per split, rooted at '<DATA_PATH>/<split>/'.\n",
"train_dataset, valid_dataset, test_dataset = (\n",
"    AtomDataset(root=f'{DATA_PATH}/{split}/') for split in ('train', 'valid', 'test')\n",
")"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"id": "afbfa5f7-61c7-49fb-b1ae-4027e882eab8",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Raw per-split attributes straight from each dataset (train_* are not used further down).\n",
"train_x, valid_x, test_x = train_dataset.x, valid_dataset.x, test_dataset.x\n",
"train_y, valid_y, test_y = train_dataset.y, valid_dataset.y, test_dataset.y"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"id": "7ca0dbeb-8c8c-48f7-b48e-b959582e25f7",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# One scalar score per graph: element [0][0] of its feature matrix\n",
"# (presumably the first feature of the first node — confirm against AtomDataset).\n",
"# Labels are re-read from the datasets rather than from valid_y/test_y so this cell\n",
"# can be re-run: calling .numpy() on an already-converted array would fail.\n",
"valid_x = np.array([float(item.x[0][0].numpy()) for item in valid_dataset])\n",
"valid_y = valid_dataset.y.numpy()\n",
"\n",
"test_x = np.array([float(item.x[0][0].numpy()) for item in test_dataset])\n",
"test_y = test_dataset.y.numpy()\n",
"\n",
"# Scores and labels must line up one-to-one for the threshold metrics below.\n",
"assert len(valid_x) == len(valid_y) and len(test_x) == len(test_y)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "1430d9f8-8906-45c5-b98b-e23bd7e34044",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Slide id of every test sample: the integer before the first '_' in its name\n",
"# (test_dataset.name is presumably the list of sample names — confirm).\n",
"slides = np.array([int(sample_name.partition('_')[0]) for sample_name in test_dataset.name])"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"id": "6f7bd837-c8b8-4d5e-a3d7-cdf094520655",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Binary target: any nonzero label is positive. Prediction: positive when the score is below the threshold.\n",
"# NOTE(review): 0.65 is presumably taken from the validation F1 sweep in the last cell — confirm.\n",
"thres = 0.65\n",
"label = (test_y != 0).astype(np.uint8)\n",
"pred = (test_x < thres).astype(np.uint8)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"id": "b84cc344-f0d3-4631-a660-0f2ac40814cc",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Boolean mask over test samples that keeps every slide except the excluded ones.\n",
"# NOTE(review): the reason slides 4 and 8 are excluded is not recorded here — document it.\n",
"EXCLUDED_SLIDES = [4, 8]\n",
"idx = ~np.isin(slides, EXCLUDED_SLIDES)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"id": "c3cec205-f782-4def-9552-359435d52f6f",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"(0.9902898294079452,\n",
|
|
" 0.9094754653130288,\n",
|
|
" 0.997525518094649,\n",
|
|
" 0.9514677681073905,\n",
|
|
" array([[30328, 321],\n",
|
|
" [ 8, 3225]]))"
|
|
]
|
|
},
|
|
"execution_count": 8,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# Accuracy, precision, recall, F1 and confusion matrix on the test samples kept by `idx`.\n",
"# `pd` is deliberately not used as a name: it is the conventional pandas alias.\n",
"lb = label[idx]\n",
"pr = pred[idx]\n",
"accuracy_score(lb, pr), precision_score(lb, pr), recall_score(lb, pr), f1_score(lb, pr), confusion_matrix(lb, pr)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "43047b50-30c6-4957-90cf-d03c344cca0b",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Per-slide metrics, computed inside a function so its loop variables do not\n",
"# overwrite notebook-level names such as `idx` used by the metrics cell above.\n",
"def per_slide_metrics(slides, label, pred, excluded=(4, 8)):\n",
"    # Returns {slide_id: (accuracy, precision, recall, f1)} for every slide not in `excluded`.\n",
"    results = {}\n",
"    for slide in sorted(set(slides.tolist()) - set(excluded)):\n",
"        mask = slides == slide\n",
"        lb, pr = label[mask], pred[mask]\n",
"        results[slide] = (accuracy_score(lb, pr), precision_score(lb, pr), recall_score(lb, pr), f1_score(lb, pr))\n",
"    return results\n",
"\n",
"\n",
"per_slide_metrics(slides, label, pred)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "19ffec85-73db-46d8-a08d-e932d38ae2ab",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"id": "3e8f7ccb-a209-44c0-a00c-191f1b91429c",
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"array([[40643, 534],\n",
|
|
" [ 8, 3875]])"
|
|
]
|
|
},
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
"# Confusion matrix over the whole test set, excluded slides included. It reuses\n",
"# `label`/`pred` so it always agrees with the threshold cell instead of recomputing them.\n",
"confusion_matrix(label, pred)"
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f01cf9b1-2176-47da-ac2f-acf6e9df09cd",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"id": "c516efc9-dcc9-47be-b167-c82843eaec8e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
"# Validation F1 for candidate thresholds 0.50, 0.51, ..., 0.99 (score < threshold => positive).\n",
"# Useful for checking the fixed test threshold `thres` against validation data.\n",
"valid_label = np.array(valid_y != 0, np.uint8)\n",
"valid_f1 = {i / 100: f1_score(valid_label, np.array(valid_x < i / 100, np.uint8)) for i in range(50, 100)}\n",
"best_valid_thres = max(valid_f1, key=valid_f1.get)\n",
"best_valid_thres, valid_f1[best_valid_thres]"
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "cmae",
|
|
"language": "python",
|
|
"name": "cmae"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|