atom-predict/egnn_v2/.ipynb_checkpoints/VisE2E-checkpoint.ipynb

131 lines
3.9 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "00bd37c1-6d87-4fba-9f64-22c287e6c014",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import glob\n",
"import json\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from PIL import Image\n",
"from utils.e2e_metrics import get_metrics\n",
"from core.data import get_y_3\n",
"from core.data import load_data\n",
"from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a349f7b4-4979-4ef3-86af-f2d89f6edc85",
"metadata": {},
"outputs": [],
"source": [
"def _plot_graph_panel(ax, img, json_path, title, show_all_points=False):\n",
"    \"\"\"Draw one panel: a grayscale image with the graph from ``json_path`` on top.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    ax : matplotlib.axes.Axes\n",
"        Axes to draw into.\n",
"    img : numpy.ndarray\n",
"        Grayscale background image.\n",
"    json_path : str\n",
"        Graph file passed to ``load_data``. Its first three return values are\n",
"        used as node coordinates, edge index pairs and node labels; column 1\n",
"        of the coordinates is plotted as x and column 0 as y.\n",
"    title : str\n",
"        Panel title.\n",
"    show_all_points : bool, optional\n",
"        If True, first draw every node as a small dot in the default colour\n",
"        (the original code did this only for the ground-truth panel).\n",
"    \"\"\"\n",
"    ax.imshow(img, cmap='gray')\n",
"    # The 4th return value of load_data (lights) is not used for plotting.\n",
"    points, edge_index, labels, _ = load_data(json_path)\n",
"    src = points[edge_index[0].astype(np.int32)]\n",
"    dst = points[edge_index[1].astype(np.int32)]\n",
"\n",
"    # One column per edge: given 2-D x/y arrays, matplotlib draws each column\n",
"    # as its own line, replacing a slow Python loop of per-edge plot calls.\n",
"    ax.plot(np.stack([src[:, 1], dst[:, 1]]), np.stack([src[:, 0], dst[:, 0]]),\n",
"            linewidth=2, c='#C0C0C0', zorder=1)\n",
"\n",
"    if show_all_points:\n",
"        ax.scatter(points[:, 1], points[:, 0], s=5, zorder=2)\n",
"    # Labels 0, 1 and 2 are drawn in one shared colour (class-agnostic view;\n",
"    # the per-class palette '#8EB9D9', '#92CE90', '#FBBE81' is disabled).\n",
"    labelled = points[np.isin(labels, (0, 1, 2))]\n",
"    ax.scatter(labelled[:, 1], labelled[:, 0], c='#8EB9D9', s=24, zorder=2)\n",
"    ax.axis('off')\n",
"    ax.set_title(title)\n",
"\n",
"\n",
"def plot_json(base_name, data_root='../../data/gnn_data', out_dir='.'):\n",
"    \"\"\"Save a ground-truth vs. end-to-end prediction figure for one sample.\n",
"\n",
"    Left panel: graph from ``<data_root>/test/raw/<base_name>.json``.\n",
"    Right panel: graph from ``<data_root>/e2e/raw/<base_name>.json``.\n",
"    Both are drawn over the grayscale image\n",
"    ``<data_root>/test/raw/<base_name>.jpg``, and the figure is written to\n",
"    ``<out_dir>/<base_name>_no_class_with_img.jpg`` at 300 dpi.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    base_name : str\n",
"        Sample identifier shared by the image and both graph files.\n",
"    data_root : str, optional\n",
"        Root of the GNN data tree (default keeps the original relative path).\n",
"    out_dir : str, optional\n",
"        Directory the figure is saved to (default: current directory).\n",
"\n",
"    Raises\n",
"    ------\n",
"    FileNotFoundError\n",
"        If the image cannot be read. ``cv2.imread`` returns ``None`` instead of\n",
"        raising, which would otherwise surface as a confusing ``imshow`` error.\n",
"    \"\"\"\n",
"    image_path = os.path.join(data_root, 'test', 'raw', '{}.jpg'.format(base_name))\n",
"    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)\n",
"    if img is None:\n",
"        raise FileNotFoundError('Could not read image: {}'.format(image_path))\n",
"\n",
"    fig, (ax_gt, ax_pd) = plt.subplots(1, 2, figsize=(18, 9))\n",
"    try:\n",
"        gt_path = os.path.join(data_root, 'test', 'raw', '{}.json'.format(base_name))\n",
"        pd_path = os.path.join(data_root, 'e2e', 'raw', '{}.json'.format(base_name))\n",
"        _plot_graph_panel(ax_gt, img, gt_path, 'Edge_GT_' + base_name, show_all_points=True)\n",
"        _plot_graph_panel(ax_pd, img, pd_path, 'Edge_Pd_' + base_name)\n",
"        fig.tight_layout()\n",
"        out_path = os.path.join(out_dir, '{}_no_class_with_img.jpg'.format(base_name))\n",
"        fig.savefig(out_path, bbox_inches='tight', dpi=300)\n",
"    finally:\n",
"        # Always release the figure, even if loading or plotting fails, so\n",
"        # repeated calls in a notebook do not accumulate open figures.\n",
"        plt.close(fig)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "649a93ff-a0ae-4ba5-a04f-6d10bd43217b",
"metadata": {},
"outputs": [],
"source": [
"# Save the ground-truth vs. end-to-end prediction comparison figure for sample '10'.\n",
"plot_json(base_name='10')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e37863b2-f376-4568-a57e-47031cf98a50",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9864661-3bf7-476f-bef5-674ef9dfd7b0",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "7515d855-5674-4ea3-ae52-787185cdf5db",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "cmae",
"language": "python",
"name": "cmae"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}