atom-predict/egnn_v2/.ipynb_checkpoints/CF-checkpoint.ipynb

303 lines
9.7 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "1d6c53c4-feee-4559-9d52-ffe80f475649",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import numpy as np\n",
"import torch\n",
"from PIL import Image\n",
"\n",
"from core.data import load_data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cbd90651-c3c0-46ce-8f60-0766aedce803",
"metadata": {},
"outputs": [],
"source": [
"def box_iou(box1, box2, eps=1e-7):\n",
"    \"\"\"\n",
"    Pairwise intersection-over-union (Jaccard index) of two sets of boxes.\n",
"\n",
"    Boxes are given as (x1, y1, x2, y2) corners. Adapted from torchvision.ops.box_iou:\n",
"    https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py\n",
"\n",
"    Arguments:\n",
"        box1 (Tensor[N, 4])\n",
"        box2 (Tensor[M, 4])\n",
"        eps: added to the union so that zero-area boxes do not divide by zero\n",
"\n",
"    Returns:\n",
"        iou (Tensor[N, M]): iou[i, j] is the IoU of box1[i] and box2[j]\n",
"    \"\"\"\n",
"    # Split every box into its top-left / bottom-right corner and broadcast\n",
"    # (N, 1, 2) against (1, M, 2) so all pairs are compared at once.\n",
"    tl1, br1 = box1[:, None].chunk(2, dim=2)\n",
"    tl2, br2 = box2[None].chunk(2, dim=2)\n",
"\n",
"    overlap = (torch.min(br1, br2) - torch.max(tl1, tl2)).clamp(min=0)\n",
"    inter = overlap.prod(dim=2)\n",
"    area1 = (br1 - tl1).prod(dim=2)\n",
"    area2 = (br2 - tl2).prod(dim=2)\n",
"    return inter / (area1 + area2 - inter + eps)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "80c7f206-e69d-4f0d-8fe2-c8ceb598d777",
"metadata": {},
"outputs": [],
"source": [
"class ConfusionMatrix:\n",
"    \"\"\"Object-detection confusion matrix with an extra background row and column.\n",
"\n",
"    Updated version of https://github.com/kaanakan/object_detection_confusion_matrix\n",
"    matrix[p, t] counts objects of true class t that were predicted as class p.\n",
"    Index nc stands for background, so classes must be 0-based integers in [0, nc).\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, nc, conf=0.25, iou_thres=0.5):\n",
"        self.matrix = np.zeros((nc + 1, nc + 1))\n",
"        self.nc = nc  # number of classes\n",
"        self.conf = conf  # detections with confidence <= conf are treated as background\n",
"        self.iou_thres = iou_thres  # a detection matches a label only if IoU > iou_thres\n",
"\n",
"    def process_batch(self, detections, labels):\n",
"        \"\"\"\n",
"        Accumulate one image's detections and ground-truth labels into the matrix.\n",
"        Boxes are expected in (x1, y1, x2, y2) format.\n",
"\n",
"        Arguments:\n",
"            detections (Tensor[N, 6]): x1, y1, x2, y2, conf, class; None if nothing was detected\n",
"            labels (Tensor[M, 5]): class, x1, y1, x2, y2. When detections is None a\n",
"                Tensor[M] holding only the classes is accepted as well.\n",
"        Returns:\n",
"            None, updates the confusion matrix in place\n",
"        \"\"\"\n",
"        if detections is None:\n",
"            # Nothing detected: every label counts as predicted background (false negative).\n",
"            gt_classes = labels[:, 0] if labels.ndim > 1 else labels\n",
"            for gc in gt_classes.int().tolist():\n",
"                self.matrix[self.nc, gc] += 1\n",
"            return\n",
"\n",
"        detections = detections[detections[:, 4] > self.conf]  # low confidence -> background\n",
"        gt_classes = labels[:, 0].int().cpu().numpy()\n",
"        detection_classes = detections[:, 5].int().cpu().numpy()\n",
"        iou = box_iou(labels[:, 1:], detections[:, :4])  # (M, N)\n",
"\n",
"        # Candidate (label, detection) pairs above the IoU threshold.\n",
"        x = torch.where(iou > self.iou_thres)\n",
"        if x[0].shape[0]:\n",
"            # rows of [label index, detection index, iou]\n",
"            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()\n",
"            if x[0].shape[0] > 1:\n",
"                # Greedy one-to-one matching by descending IoU: keep the best label for each\n",
"                # detection, then the best detection for each label.\n",
"                matches = matches[matches[:, 2].argsort()[::-1]]\n",
"                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]\n",
"                matches = matches[matches[:, 2].argsort()[::-1]]\n",
"                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]\n",
"        else:\n",
"            matches = np.zeros((0, 3))\n",
"\n",
"        m0 = matches[:, 0].astype(int)  # matched label indices\n",
"        m1 = matches[:, 1].astype(int)  # matched detection indices\n",
"\n",
"        # Matched pairs: row = predicted class, column = true class.\n",
"        np.add.at(self.matrix, (detection_classes[m1], gt_classes[m0]), 1)\n",
"        # Unmatched labels: background predicted for a real object (false negatives).\n",
"        missed = np.setdiff1d(np.arange(len(gt_classes)), m0)\n",
"        np.add.at(self.matrix, (self.nc, gt_classes[missed]), 1)\n",
"        # Unmatched detections: object predicted on background (false positives).\n",
"        # These must be counted even when no pair matched at all.\n",
"        spurious = np.setdiff1d(np.arange(len(detection_classes)), m1)\n",
"        np.add.at(self.matrix, (detection_classes[spurious], self.nc), 1)\n",
"\n",
"    def tp_fp(self):\n",
"        \"\"\"Return per-class true-positive and false-positive counts, background excluded.\"\"\"\n",
"        tp = self.matrix.diagonal()  # true positives\n",
"        fp = self.matrix.sum(1) - tp  # false positives: predicted as the class but wrong\n",
"        # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)\n",
"        return tp[:-1], fp[:-1]  # remove background class\n",
"\n",
"    def print(self):\n",
"        \"\"\"Print the matrix, one row per line.\"\"\"\n",
"        for i in range(self.nc + 1):\n",
"            print(' '.join(map(str, self.matrix[i])))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c1a01351-da0b-4cf7-92fd-094966af2907",
"metadata": {},
"outputs": [],
"source": [
"# Prediction JSON to evaluate; the ground-truth mask is read from the .png of the same name.\n",
"# Set the CF_JSON_PATH environment variable instead of editing this machine-specific default.\n",
"json_path = os.environ.get('CF_JSON_PATH', '/home/andrewtal/Workspace/metrials/results/our/10.json')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "60c50e35-7f04-4449-8d70-929efc80d66b",
"metadata": {},
"outputs": [],
"source": [
"# Predictions: points are indexed as (row, col) into the mask, one class per point in pred_labels.\n",
"# (Named pred_labels so the ground-truth `labels` built further down does not clobber it.)\n",
"points, edge_index, pred_labels, _ = load_data(json_path)\n",
"\n",
"# Ground-truth class mask stored next to the JSON; non-zero pixels are labelled objects.\n",
"mask_gt = np.array(Image.open(json_path.replace('.json', '.png')), np.uint8)\n",
"assert mask_gt.ndim == 2, 'expected a single-channel ground-truth mask'\n",
"\n",
"# Rasterise the predictions onto a mask of the same size: class + 1 per point, 0 = background.\n",
"mask_pd = np.zeros(mask_gt.shape, np.uint8)\n",
"mask_pd[points[:, 0], points[:, 1]] = pred_labels + 1"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "da527085-7cac-4cf1-92eb-b1c3b2617566",
"metadata": {},
"outputs": [],
"source": [
"# Box formats expected by ConfusionMatrix.process_batch:\n",
"#   detections: (N, 6) rows of x1, y1, x2, y2, conf, class\n",
"#   labels:     (M, 5) rows of class, x1, y1, x2, y2"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a694e738-4cc0-4e39-8b3d-5099e73b09f4",
"metadata": {},
"outputs": [],
"source": [
"def mask_to_boxes(mask, half_size=10.):\n",
"    \"\"\"Turn every non-zero pixel of a class mask into one square box.\n",
"\n",
"    A pixel at (row, col) becomes the box (row, col, row + 2 * half_size, col + 2 * half_size),\n",
"    i.e. a box of side 2 * half_size centred on the pixel and shifted by +half_size so the\n",
"    coordinates stay non-negative. Predictions and ground truth get the same shift and the\n",
"    same (row, col) axis order, so IoU between them is unaffected.\n",
"\n",
"    Masks store class + 1 with 0 meaning background, so the returned classes are shifted back\n",
"    to 0-based: ConfusionMatrix(nc) reserves index nc for background and expects classes in\n",
"    [0, nc). The ground-truth PNG is assumed to use the same class + 1 encoding as mask_pd\n",
"    (TODO: confirm against the annotation source).\n",
"\n",
"    Returns:\n",
"        classes (ndarray[K]) and boxes (ndarray[K, 4]), both float64\n",
"    \"\"\"\n",
"    coords = np.argwhere(mask != 0)\n",
"    classes = mask[coords[:, 0], coords[:, 1]].astype(np.float64) - 1\n",
"    boxes = np.hstack([coords, coords + 2 * half_size]).astype(np.float64)\n",
"    return classes, boxes\n",
"\n",
"\n",
"pred_classes, pred_boxes = mask_to_boxes(mask_pd)\n",
"# Masks carry no score, so every prediction gets confidence 1 and survives the conf filter.\n",
"detections = np.column_stack([pred_boxes, np.ones_like(pred_classes), pred_classes])\n",
"detections = torch.FloatTensor(detections)  # (N, 6): x1, y1, x2, y2, conf, class"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "9a1d5542-81cf-45b1-95bb-e3a16c74586a",
"metadata": {},
"outputs": [],
"source": [
"gt_classes, gt_boxes = mask_to_boxes(mask_gt)\n",
"labels = torch.FloatTensor(np.column_stack([gt_classes, gt_boxes]))  # (M, 5): class, x1, y1, x2, y2"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "4aec5d29-cb07-4e1c-83a5-98cf226273dc",
"metadata": {},
"outputs": [],
"source": [
"NUM_CLASSES = 3  # foreground classes; the matrix gets one extra row/column for background\n",
"cm = ConfusionMatrix(nc=NUM_CLASSES, conf=0.25, iou_thres=0.5)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "89bfb695-710b-43b4-807f-95f52b7f93bb",
"metadata": {},
"outputs": [],
"source": [
"# process_batch accumulates into cm.matrix, so clear it first to keep this cell re-runnable.\n",
"cm.matrix[:] = 0\n",
"cm.process_batch(detections, labels)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0572b149-84c4-451f-9d5e-dcb5843c270e",
"metadata": {},
"outputs": [],
"source": [
"# Rows: predicted class, columns: true class, last row/column: background.\n",
"# Use a wide integer type: int8 wraps counts above 127 into negative numbers.\n",
"cm.matrix.astype(np.int64)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8043bcd-b10f-4731-8588-915b83070fab",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "f59853b6-bbfc-497a-9714-62339514b024",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "7ff9b3b7-846a-45e4-84f7-5c0fd0457c6d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "31da26eb-3abc-42cd-94ca-7c1d43025227",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "cmae",
"language": "python",
"name": "cmae"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}