atom-predict/egnn_v2/.ipynb_checkpoints/VisS1-checkpoint.ipynb

143 lines
3.7 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "00bd37c1-6d87-4fba-9f64-22c287e6c014",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import glob\n",
"import json\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from PIL import Image\n",
"from utils.e2e_metrics import get_metrics\n",
"from core.data import get_y_3\n",
"from core.data import load_data\n",
"from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a349f7b4-4979-4ef3-86af-f2d89f6edc85",
"metadata": {},
"outputs": [],
"source": [
"def plot_json(json_path):\n",
"    \"\"\"Plot the points and edges stored in ``json_path`` beside its image.\n",
"\n",
"    The image is read from the file that shares the path of ``json_path``\n",
"    but has a ``.jpg`` extension. The left panel shows the image with every\n",
"    point overlaid in red; the right panel shows the edges (green) and the\n",
"    points on a white canvas with the same height and width as the image.\n",
"\n",
"    Args:\n",
"        json_path: Path to a ``.json`` file that ``load_data`` can read.\n",
"\n",
"    Returns:\n",
"        None.\n",
"    \"\"\"\n",
"    # Local import: only this function needs LineCollection.\n",
"    from matplotlib.collections import LineCollection\n",
"\n",
"    # Swap only the extension; str.replace would also rewrite '.json'\n",
"    # appearing anywhere else in the path.\n",
"    stem, _ = os.path.splitext(json_path)\n",
"    base_name = os.path.basename(stem)\n",
"    img_path = stem + '.jpg'\n",
"    # Close the file handle as soon as the pixels are copied into an array.\n",
"    with Image.open(img_path) as im:\n",
"        img = np.array(im)\n",
"\n",
"    # Labels and lights are returned by load_data but not needed for plotting.\n",
"    points, edge_index, _labels, _lights = load_data(json_path)\n",
"    # Column 0 of `points` is used as the row (y), column 1 as the column (x).\n",
"    xs, ys = points[:, 1], points[:, 0]\n",
"\n",
"    # The two rows of `edge_index` index the end points of every edge.\n",
"    src = points[edge_index[0].astype(np.int32)]\n",
"    dst = points[edge_index[1].astype(np.int32)]\n",
"\n",
"    fig, (ax_points, ax_edges) = plt.subplots(1, 2, figsize=(16, 9))\n",
"\n",
"    ax_points.imshow(img, cmap='gray')\n",
"    ax_points.scatter(xs, ys, s=5, zorder=2, c='red')\n",
"    ax_points.axis('off')\n",
"    ax_points.set_title('Points_Pred_' + base_name)\n",
"\n",
"    # White canvas sized from the image rather than a hard-coded 2048x2048;\n",
"    # fixed vmin/vmax keep a constant array white under the gray colormap.\n",
"    canvas = np.full(img.shape[:2], 255, dtype=np.uint8)\n",
"    ax_edges.imshow(canvas, cmap='gray', vmin=0, vmax=255)\n",
"    # Draw all edges as one collection instead of one plot() call per edge.\n",
"    # Each segment is ((x0, y0), (x1, y1)), hence the [1, 0] column swap.\n",
"    segments = np.stack([src[:, [1, 0]], dst[:, [1, 0]]], axis=1)\n",
"    if len(segments):\n",
"        ax_edges.add_collection(\n",
"            LineCollection(segments, colors='green', linewidths=1, zorder=1)\n",
"        )\n",
"    ax_edges.scatter(xs, ys, s=5, zorder=2)\n",
"    ax_edges.axis('off')\n",
"    ax_edges.set_title('Edge_Pred_' + base_name)\n",
"\n",
"    fig.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7902c91f-06e0-4582-b8cb-af64bb5dafe5",
"metadata": {},
"outputs": [],
"source": [
"# glob.glob returns matches in arbitrary, filesystem-dependent order; sort so\n",
"# the list (and any index into it) is the same on every run.\n",
"# `recursive=True` was dropped: it only affects patterns containing '**'.\n",
"json_lst = sorted(glob.glob('../../data/gnn_data/e2e/raw/*.json'))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e89f6c06-9fdc-421f-9b6a-9c1cbcf7c857",
"metadata": {},
"outputs": [],
"source": [
"name = '4'\n",
"target = name + '.json'\n",
"\n",
"# Compare the exact file name: a substring test on the whole path would also\n",
"# select e.g. '14.json' or '24.json' when looking for '4.json'.\n",
"matches = [idx for idx, path in enumerate(json_lst) if os.path.basename(path) == target]\n",
"if not matches:\n",
"    raise ValueError(f'{target!r} not found among the {len(json_lst)} paths in json_lst')\n",
"i = matches[0]\n",
"\n",
"plot_json(json_lst[i])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "649a93ff-a0ae-4ba5-a04f-6d10bd43217b",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "e37863b2-f376-4568-a57e-47031cf98a50",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9864661-3bf7-476f-bef5-674ef9dfd7b0",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "7515d855-5674-4ea3-ae52-787185cdf5db",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "cmae",
"language": "python",
"name": "cmae"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}