atom-predict/egnn_v2/.ipynb_checkpoints/GrayScale-checkpoint.ipynb

212 lines
5.0 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "a462eb92-a60a-45c9-a7c9-9256109f9963",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from core.data import AtomDataset\n",
"from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix\n",
"from sklearn.ensemble import RandomForestClassifier"
]
},
{
"cell_type": "markdown",
"id": "a7f82353-0ef0-4fc9-91c3-3680625698dc",
"metadata": {},
"source": [
"# GS (grayscale) threshold classifier"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "08eef057-7728-43a5-a415-063c301fdab6",
"metadata": {},
"outputs": [],
"source": [
"# Root folder holding one sub-directory per split. It already ends in '/', so\n",
"# the roots built below contain a doubled slash (e.g. 'gnn_data//train/').\n",
"DATA_PATH = '../../data/gnn_data/'\n",
"\n",
"# Build the three AtomDataset splits in train / valid / test order.\n",
"train_dataset, valid_dataset, test_dataset = [\n",
"    AtomDataset(root=f'{DATA_PATH}/{split}/')\n",
"    for split in ('train', 'valid', 'test')\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "afbfa5f7-61c7-49fb-b1ae-4027e882eab8",
"metadata": {},
"outputs": [],
"source": [
"# Pull the x and y attributes off each dataset split.\n",
"train_x = train_dataset.x\n",
"train_y = train_dataset.y\n",
"\n",
"valid_x = valid_dataset.x\n",
"valid_y = valid_dataset.y\n",
"\n",
"test_x = test_dataset.x\n",
"test_y = test_dataset.y"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7ca0dbeb-8c8c-48f7-b48e-b959582e25f7",
"metadata": {},
"outputs": [],
"source": [
"# One scalar per sample: entry [0][0] of each item's x, collected into a 1-D array.\n",
"valid_x = np.array([float(item.x[0][0].numpy()) for item in valid_dataset])\n",
"# Labels are read straight from the dataset rather than from the previous\n",
"# valid_y, so the cell is idempotent: calling .numpy() on an array that was\n",
"# already converted by an earlier run of this cell would raise AttributeError.\n",
"valid_y = valid_dataset.y.numpy()\n",
"\n",
"# Same conversion for the test split.\n",
"test_x = np.array([float(item.x[0][0].numpy()) for item in test_dataset])\n",
"test_y = test_dataset.y.numpy()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "1430d9f8-8906-45c5-b98b-e23bd7e34044",
"metadata": {},
"outputs": [],
"source": [
"# Slide id of every test sample: the integer prefix of its name, i.e. the\n",
"# text before the first '_'.\n",
"slides = np.array([int(name.partition('_')[0]) for name in test_dataset.name])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "6f7bd837-c8b8-4d5e-a3d7-cdf094520655",
"metadata": {},
"outputs": [],
"source": [
"# Decision threshold applied to the scalar feature in test_x.\n",
"thres = 0.65\n",
"\n",
"# Binary ground truth: 1 wherever the test label is non-zero, else 0.\n",
"label = (test_y != 0).astype(np.uint8)\n",
"# Binary prediction: 1 wherever the feature is below the threshold, else 0.\n",
"pred = (test_x < thres).astype(np.uint8)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "b84cc344-f0d3-4631-a660-0f2ac40814cc",
"metadata": {},
"outputs": [],
"source": [
"# Boolean mask that is True for every test sample whose slide id is not 4 or 8.\n",
"# Built with a vectorized membership test so the mask is an ndarray rather\n",
"# than a Python list of bools.\n",
"idx = ~np.isin(slides, [4, 8])"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c3cec205-f782-4def-9552-359435d52f6f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(0.9902898294079452,\n",
" 0.9094754653130288,\n",
" 0.997525518094649,\n",
" 0.9514677681073905,\n",
" array([[30328, 321],\n",
" [ 8, 3225]]))"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Keep only the samples selected by the idx mask.\n",
"# NOTE: `pd` here is the filtered prediction array, not the pandas alias.\n",
"lb, pd = label[idx], pred[idx]\n",
"\n",
"# Displayed as (accuracy, precision, recall, F1, confusion matrix).\n",
"(\n",
"    accuracy_score(lb, pd),\n",
"    precision_score(lb, pd),\n",
"    recall_score(lb, pd),\n",
"    f1_score(lb, pd),\n",
"    confusion_matrix(lb, pd),\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "43047b50-30c6-4957-90cf-d03c344cca0b",
"metadata": {},
"outputs": [],
"source": [
"# for item in set(slides):\n",
"# if item in [4, 8]:\n",
"# continue\n",
"# idx = slides == item\n",
"# lb = label[idx]\n",
"# pd = pred[idx]\n",
"# print(item, accuracy_score(lb, pd), precision_score(lb, pd), recall_score(lb, pd), f1_score(lb, pd))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19ffec85-73db-46d8-a08d-e932d38ae2ab",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3e8f7ccb-a209-44c0-a00c-191f1b91429c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[40643, 534],\n",
" [ 8, 3875]])"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Confusion matrix over the whole test set (no slide filtering). Reuses the\n",
"# label / pred arrays derived from `thres` earlier in the notebook instead of\n",
"# recomputing them, so the thresholding rule lives in one place.\n",
"confusion_matrix(label, pred)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f01cf9b1-2176-47da-ac2f-acf6e9df09cd",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 5,
"id": "c516efc9-dcc9-47be-b167-c82843eaec8e",
"metadata": {},
"outputs": [],
"source": [
"# for i in range(50, 100):\n",
"# print(i/100, f1_score(np.array(valid_y != 0, np.uint8), np.array(valid_x < i/100, np.uint8)))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "cmae",
"language": "python",
"name": "cmae"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}