atom-predict/msunet/.ipynb_checkpoints/Vis-Test-checkpoint.ipynb

172 lines
4.6 KiB
Plaintext
Executable File

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "f7589eab-423f-48be-88cc-96348b018bc7",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import glob\n",
"import json\n",
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"\n",
"from PIL import Image\n",
"from skimage.feature import peak_local_max\n",
"from skimage.measure import label\n",
"from skimage.measure import regionprops"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "931d4ec3-4ca7-4b51-a3bd-9affccd46ecf",
"metadata": {},
"outputs": [],
"source": [
"def hough_center_detection(i, rp, labeled_img, img_size=2048):\n",
"    \"\"\"Locate circle centres inside one labelled region with a Hough transform.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    i : int\n",
"        Index of the region; only used by the diagnostic print below.\n",
"    rp : skimage.measure RegionProperties\n",
"        Region whose ``bbox``, ``label`` and ``area`` are read.\n",
"    labeled_img : np.ndarray\n",
"        2-D label image that ``rp`` was computed from.\n",
"    img_size : int, default 2048\n",
"        Upper bound on the padded crop. The crop is also limited to\n",
"        ``labeled_img.shape``, so images of other sizes work as well.\n",
"\n",
"    Returns\n",
"    -------\n",
"    np.ndarray\n",
"        ``(N, 2)`` int32 array of ``(row, col)`` centres in full-image\n",
"        coordinates, or an empty array when no circle is found.\n",
"    \"\"\"\n",
"    pad = 5  # margin so circles touching the bbox border are not cut off\n",
"    n_rows = min(img_size, labeled_img.shape[0])\n",
"    n_cols = min(img_size, labeled_img.shape[1])\n",
"\n",
"    hs, ws, he, we = rp.bbox\n",
"    # bbox ends are exclusive, so clip them to the size (not size - 1);\n",
"    # otherwise regions touching the last row/column lose that row/column\n",
"    hs = max(hs - pad, 0)\n",
"    ws = max(ws - pad, 0)\n",
"    he = min(he + pad, n_rows)\n",
"    we = min(we + pad, n_cols)\n",
"\n",
"    # crop first, then compare: avoids building a full-image mask per region\n",
"    m = (labeled_img[hs:he, ws:we] == rp.label).astype(np.uint8)\n",
"\n",
"    circles = cv2.HoughCircles(\n",
"        m,\n",
"        method=cv2.HOUGH_GRADIENT,\n",
"        dp=1,\n",
"        minDist=14,\n",
"        minRadius=5,\n",
"        maxRadius=12,\n",
"        param1=5,\n",
"        param2=6,\n",
"    )\n",
"\n",
"    if circles is None:\n",
"        return np.array([])\n",
"\n",
"    # circles has shape (1, N, 3) with rows (x, y, radius). Report large\n",
"    # regions that did not split into exactly two circles\n",
"    # (NOTE(review): presumably such regions are expected to hold two).\n",
"    if rp.area > 400 and circles.shape[1] != 2:\n",
"        print(i)\n",
"\n",
"    # (x, y) -> (row, col), then shift back into full-image coordinates\n",
"    centers = np.round(circles[0][:, :2][:, ::-1] + [hs, ws]).astype(np.int32)\n",
"    # defensive: rounding may land one past the last pixel; keep centres indexable\n",
"    centers[:, 0] = np.clip(centers[:, 0], 0, labeled_img.shape[0] - 1)\n",
"    centers[:, 1] = np.clip(centers[:, 1], 0, labeled_img.shape[1] - 1)\n",
"    return centers"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "0dc26f52-7985-48c9-a7b7-6d9243dcc5a7",
"metadata": {},
"outputs": [],
"source": [
"def get_mask(probs, min_size=8):\n",
"    \"\"\"Mark the centroid of every sufficiently large blob in a probability map.\n",
"\n",
"    ``probs`` is binarised with Otsu's threshold; each connected component\n",
"    with at least ``min_size`` pixels contributes one pixel set to 1 at its\n",
"    rounded centroid.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    probs : np.ndarray\n",
"        2-D probability map, presumably in [0, 1]; values outside that\n",
"        range are clipped before scaling to 8 bits.\n",
"    min_size : int, default 8\n",
"        Minimum component area in pixels.\n",
"\n",
"    Returns\n",
"    -------\n",
"    np.ndarray\n",
"        float64 array shaped like ``probs``: 1 at centroids, 0 elsewhere.\n",
"    \"\"\"\n",
"    # Otsu needs an 8-bit image; clip first so out-of-range values cannot wrap\n",
"    binary = np.array(np.clip(probs, 0., 1.) * 255., np.uint8)\n",
"    _, binary = cv2.threshold(binary, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)\n",
"\n",
"    mask = np.zeros(binary.shape)\n",
"    for rp in regionprops(label(binary)):\n",
"        if rp.area < min_size:\n",
"            continue\n",
"        # a centroid is a mean of in-image pixel coordinates, so rounding stays in bounds\n",
"        h, w = np.round(rp.centroid).astype(np.int32)\n",
"        mask[h, w] = 1\n",
"\n",
"    return mask"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "3d439e0d-b3a7-45ab-82e7-b34b1b5b182c",
"metadata": {},
"outputs": [],
"source": [
"def get_mask_v2(probs, min_size=32):\n",
"    \"\"\"Mark blob centres in a probability map, splitting blobs with Hough circles.\n",
"\n",
"    ``probs`` is binarised with Otsu's threshold. Components smaller than\n",
"    ``min_size`` pixels are dropped. For each kept component,\n",
"    ``hough_center_detection`` proposes circle centres; if it finds none,\n",
"    the component's rounded centroid is used instead.\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    probs : np.ndarray\n",
"        2-D probability map, presumably in [0, 1]; values outside that\n",
"        range are clipped before scaling to 8 bits.\n",
"    min_size : int, default 32\n",
"        Minimum component area in pixels. 32 is the threshold this function\n",
"        previously hard-coded (the argument used to be ignored).\n",
"\n",
"    Returns\n",
"    -------\n",
"    np.ndarray\n",
"        float64 array shaped like ``probs`` with 1 at every detected centre.\n",
"    \"\"\"\n",
"    # Otsu needs an 8-bit image; clip first so out-of-range values cannot wrap\n",
"    binary = np.array(np.clip(probs, 0., 1.) * 255., np.uint8)\n",
"    _, binary = cv2.threshold(binary, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)\n",
"\n",
"    mask = np.zeros(binary.shape)\n",
"    labeled_img = label(binary)\n",
"\n",
"    for i, rp in enumerate(regionprops(labeled_img)):\n",
"        if rp.area < min_size:\n",
"            continue\n",
"\n",
"        rp_centers = hough_center_detection(i, rp, labeled_img)\n",
"\n",
"        if len(rp_centers) == 0:\n",
"            # no circle found: fall back to the component centroid\n",
"            h, w = np.round(rp.centroid).astype(np.int32)\n",
"            mask[h, w] = 1\n",
"        else:\n",
"            # rp_centers is an (N, 2) array of (row, col) indices\n",
"            mask[rp_centers[:, 0], rp_centers[:, 1]] = 1\n",
"\n",
"    return mask"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "86ba3768-f686-4995-a8d6-77eb675c6702",
"metadata": {},
"outputs": [],
"source": [
"colors = ['red', 'yellow', 'blue']  # presumably plot colours for later cells\n",
"\n",
"# Read the saved test results: image paths, predictions and labels.\n",
"with open('./logs/0/version_0/test.json') as f:\n",
"    data = json.load(f)\n",
"\n",
"# Turn each list into a numpy array (keys read in the order listed).\n",
"img_path, pred, lb = (\n",
"    np.array(data[key]) for key in ('img_path', 'pred', 'label')\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7fbc6c01-7380-4f80-9326-15aa48aa1c2c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "cmae",
"language": "python",
"name": "cmae"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}