{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1d6c53c4-feee-4559-9d52-ffe80f475649",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from core.data import load_data\n",
    "from PIL import Image\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cbd90651-c3c0-46ce-8f60-0766aedce803",
   "metadata": {},
   "outputs": [],
   "source": [
    "def box_iou(box1, box2, eps=1e-7):\n",
    "    \"\"\"\n",
    "    Return intersection-over-union (Jaccard index) of boxes.\n",
    "    Both sets of boxes are expected to be in (x1, y1, x2, y2) format.\n",
    "    Based on https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py\n",
    "\n",
    "    Arguments:\n",
    "        box1 (Tensor[N, 4])\n",
    "        box2 (Tensor[M, 4])\n",
    "        eps (float): small constant added to the denominator to avoid division by zero\n",
    "\n",
    "    Returns:\n",
    "        iou (Tensor[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2\n",
    "    \"\"\"\n",
    "\n",
    "    # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)\n",
    "    (a1, a2), (b1, b2) = box1.unsqueeze(1).chunk(2, 2), box2.unsqueeze(0).chunk(2, 2)\n",
    "    inter = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)\n",
    "\n",
    "    # IoU = inter / (area1 + area2 - inter)\n",
    "    return inter / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - inter + eps)\n"
   ]
  },
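  {
   "cell_type": "markdown",
   "id": "box-iou-sanity-check-md",
   "metadata": {},
   "source": [
    "A quick sanity check of `box_iou` on hand-made boxes (an illustrative sketch added here, not part of the original analysis): an identical pair should give an IoU of 1, and a box shifted by half its width should give roughly 1/3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "box-iou-sanity-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative check of box_iou; the boxes below are made up for the example.\n",
    "boxes_a = torch.tensor([[0., 0., 10., 10.],\n",
    "                        [0., 0., 10., 10.]])\n",
    "boxes_b = torch.tensor([[0., 0., 10., 10.],   # identical box -> IoU == 1\n",
    "                        [5., 0., 15., 10.]])  # half overlap  -> IoU == 1/3\n",
    "box_iou(boxes_a, boxes_b)  # pairwise 2x2 IoU matrix"
   ]
  },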
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "80c7f206-e69d-4f0d-8fe2-c8ceb598d777",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConfusionMatrix:\n",
    "    # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix\n",
    "    def __init__(self, nc, conf=0.25, iou_thres=0.5):\n",
    "        self.matrix = np.zeros((nc + 1, nc + 1))\n",
    "        self.nc = nc  # number of classes\n",
    "        self.conf = conf  # class confidence threshold\n",
    "        self.iou_thres = iou_thres  # IoU threshold\n",
    "\n",
    "    def process_batch(self, detections, labels):\n",
    "        \"\"\"\n",
    "        Update the confusion matrix with one batch of detections and ground-truth labels.\n",
    "        Boxes are expected to be in (x1, y1, x2, y2) format.\n",
    "        Arguments:\n",
    "            detections (Array[N, 6]), x1, y1, x2, y2, conf, class\n",
    "            labels (Array[M, 5]), class, x1, y1, x2, y2\n",
    "        Returns:\n",
    "            None, updates confusion matrix accordingly\n",
    "        \"\"\"\n",
    "        if detections is None:\n",
    "            gt_classes = labels.int()\n",
    "            for gc in gt_classes:\n",
    "                self.matrix[self.nc, gc] += 1  # predicted background, but ground truth is an object\n",
    "            return\n",
    "\n",
    "        detections = detections[detections[:, 4] > self.conf]  # detections below this conf are treated as background\n",
    "        gt_classes = labels[:, 0].int()  # ground-truth classes\n",
    "        detection_classes = detections[:, 5].int()  # predicted classes\n",
    "        iou = box_iou(labels[:, 1:], detections[:, :4])  # IoU between every ground-truth box and every detection\n",
    "\n",
    "        x = torch.where(iou > self.iou_thres)  # indices of pairs above the IoU threshold: x = (dim0 indices, dim1 indices)\n",
    "        if x[0].shape[0]:  # x[0]: matched ground-truth indices, x[1]: the corresponding detection indices\n",
    "            # shape: [n, 3], columns -> [label index, detection index, iou]\n",
    "            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()\n",
    "            if x[0].shape[0] > 1:\n",
    "                matches = matches[matches[:, 2].argsort()[::-1]]  # sort by IoU, highest first\n",
    "                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]  # if one detection matches several ground truths, keep the highest-IoU match\n",
    "                matches = matches[matches[:, 2].argsort()[::-1]]  # sort by IoU, highest first\n",
    "                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]  # if one ground truth matches several detections, keep the highest-IoU match\n",
    "        else:\n",
    "            matches = np.zeros((0, 3))\n",
    "\n",
    "        n = matches.shape[0] > 0  # whether any detection was matched to a ground truth\n",
    "        m0, m1, _ = matches.transpose().astype(int)  # m0: ground-truth indices, m1: detection indices\n",
    "        for i, gc in enumerate(gt_classes):  # loop over ground truths\n",
    "            j = m0 == i  # matches whose ground-truth index is i\n",
    "            if n and sum(j) == 1:  # this ground truth was detected\n",
    "                self.matrix[detection_classes[m1[j]], gc] += 1  # predicted an object, and ground truth is an object\n",
    "            else:  # this ground truth was missed\n",
    "                self.matrix[self.nc, gc] += 1  # predicted background, but ground truth is an object\n",
    "\n",
    "        if n:\n",
    "            for i, dc in enumerate(detection_classes):  # loop over detections\n",
    "                if not any(m1 == i):  # this detection was not matched to any ground truth\n",
    "                    self.matrix[dc, self.nc] += 1  # predicted an object, but ground truth is background\n",
    "\n",
    "    def tp_fp(self):\n",
    "        tp = self.matrix.diagonal()  # true positives\n",
    "        fp = self.matrix.sum(1) - tp  # false positives\n",
    "        # fn = self.matrix.sum(0) - tp  # false negatives (missed detections)\n",
    "        return tp[:-1], fp[:-1]  # remove background class\n",
    "\n",
    "    def print(self):\n",
    "        for i in range(self.nc + 1):\n",
    "            print(' '.join(map(str, self.matrix[i])))"
   ]
  },
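  {
   "cell_type": "markdown",
   "id": "confusion-matrix-demo-md",
   "metadata": {},
   "source": [
    "A tiny synthetic example (added as an illustrative sketch; the boxes and classes below are made up) showing how `process_batch` fills the matrix: rows are indexed by predicted class, columns by ground-truth class, and the extra last row/column represents background, i.e. missed objects and background false positives."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "confusion-matrix-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative example with two made-up classes; not part of the original data.\n",
    "cm_demo = ConfusionMatrix(nc=2)\n",
    "demo_detections = torch.tensor([[0., 0., 10., 10., 0.9, 0.],    # matches the class-0 ground truth\n",
    "                                [50., 50., 60., 60., 0.8, 1.]])  # overlaps nothing -> background false positive\n",
    "demo_labels = torch.tensor([[0., 0., 0., 10., 10.],          # class 0, detected\n",
    "                            [1., 100., 100., 110., 110.]])   # class 1, missed\n",
    "cm_demo.process_batch(demo_detections, demo_labels)\n",
    "cm_demo.matrix  # expect 1 at [0, 0] (TP), [1, 2] (background FP), [2, 1] (missed gt)"
   ]
  },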
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c1a01351-da0b-4cf7-92fd-094966af2907",
   "metadata": {},
   "outputs": [],
   "source": [
    "json_path = '/home/andrewtal/Workspace/metrials/results/our/10.json'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "60c50e35-7f04-4449-8d70-929efc80d66b",
   "metadata": {},
   "outputs": [],
   "source": [
    "points, edge_index, labels, _ = load_data(json_path)\n",
    "\n",
    "mask_pd = np.zeros((2048, 2048))\n",
    "mask_pd[points[:, 0], points[:, 1]] = labels + 1\n",
    "mask_pd = np.array(mask_pd, np.uint8)\n",
    "\n",
    "mask_gt = np.array(Image.open(json_path.replace('.json', '.png')), np.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "da527085-7cac-4cf1-92eb-b1c3b2617566",
   "metadata": {},
   "outputs": [],
   "source": [
    "# detections (Array[N, 6]), x1, y1, x2, y2, conf, class\n",
    "# labels (Array[M, 5]), class, x1, y1, x2, y2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a694e738-4cc0-4e39-8b3d-5099e73b09f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "ppd = np.array(np.where(mask_pd != 0)).T\n",
    "detections = []\n",
    "for h, w in ppd:\n",
    "    cl = mask_pd[h, w]\n",
    "    x1 = h - 10.\n",
    "    x2 = h + 10.\n",
    "    y1 = w - 10.\n",
    "    y2 = w + 10.\n",
    "    conf = 1.\n",
    "\n",
    "    detections += [[x1, y1, x2, y2, conf, cl]]\n",
    "\n",
    "detections = np.array(detections)\n",
    "detections[:, :4] += 10\n",
    "detections = torch.FloatTensor(detections)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9a1d5542-81cf-45b1-95bb-e3a16c74586a",
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = []\n",
    "pgt = np.array(np.where(mask_gt != 0)).T\n",
    "\n",
    "for h, w in pgt:\n",
    "    cl = mask_gt[h, w]\n",
    "    x1 = h - 10.\n",
    "    x2 = h + 10.\n",
    "    y1 = w - 10.\n",
    "    y2 = w + 10.\n",
    "\n",
    "    labels += [[cl, x1, y1, x2, y2]]\n",
    "\n",
    "labels = np.array(labels)\n",
    "labels[:, 1:] += 10\n",
    "labels = torch.FloatTensor(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4aec5d29-cb07-4e1c-83a5-98cf226273dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "cm = ConfusionMatrix(nc=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "89bfb695-710b-43b4-807f-95f52b7f93bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "cm.process_batch(detections, labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0572b149-84c4-451f-9d5e-dcb5843c270e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[  0,   0,   0,   0],\n",
       "       [  0, -19,   4,  10],\n",
       "       [  0,   2,  99,   4],\n",
       "       [  0,  15,   6,  86]], dtype=int8)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(cm.matrix, np.int8)"
   ]
  },
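  {
   "cell_type": "markdown",
   "id": "int8-overflow-note-md",
   "metadata": {},
   "source": [
    "Note (added): the counts in `cm.matrix` can exceed 127, so casting to `np.int8` above wraps around; the `-19` entry is most likely an overflowed count, not a negative value. The sketch below views the same matrix with a wider integer dtype and prints the per-class TP/FP summary from `tp_fp`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "int8-overflow-note",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: same counts, but a dtype wide enough to hold them, plus the TP/FP summary.\n",
    "print(cm.matrix.astype(np.int64))\n",
    "tp, fp = cm.tp_fp()\n",
    "print('per-class TP:', tp)\n",
    "print('per-class FP:', fp)"
   ]
  }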
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "cmae",
   "language": "python",
   "name": "cmae"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}