199 lines
7.9 KiB
Plaintext
199 lines
7.9 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"id": "1d6c53c4-feee-4559-9d52-ffe80f475649",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"ename": "ModuleNotFoundError",
|
||
"evalue": "No module named 'core'",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[0;32mIn[5], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mcore\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdata\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m load_data\n",
|
||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'core'"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from core.data import load_data"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"id": "80c7f206-e69d-4f0d-8fe2-c8ceb598d777",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"class ConfusionMatrix:\n",
|
||
" # Updated version of https://github.com/kaanakan/object_detection_confusion_matrix\n",
|
||
" def __init__(self, nc, conf=0.25, iou_thres=0.5):\n",
|
||
" self.matrix = np.zeros((nc + 1, nc + 1))\n",
|
||
" self.nc = nc # number of classes\n",
|
||
" self.conf = conf # 类别置信度\n",
|
||
" self.iou_thres = iou_thres # IoU置信度\n",
|
||
"\n",
|
||
" def process_batch(self, detections, labels):\n",
|
||
" \"\"\"\n",
|
||
" Update the confusion matrix for one batch, matching detections to labels by intersection-over-union (Jaccard index) of boxes.\n",
|
||
" Both sets of boxes are expected to be in (x1, y1, x2, y2) format.\n",
|
||
" Arguments:\n",
|
||
" detections (Array[N, 6]), x1, y1, x2, y2, conf, class\n",
|
||
" labels (Array[M, 5]), class, x1, y1, x2, y2\n",
|
||
" Returns:\n",
|
||
" None, updates confusion matrix accordingly\n",
|
||
" \"\"\"\n",
|
||
" if detections is None:\n",
|
||
" gt_classes = labels.int()\n",
|
||
" for gc in gt_classes:\n",
|
||
" self.matrix[self.nc, gc] += 1 # 预测为背景,但实际为目标\n",
|
||
" return\n",
|
||
"\n",
|
||
" detections = detections[detections[:, 4] > self.conf] # 小于该conf认为为背景\n",
|
||
" gt_classes = labels[:, 0].int() # 实际类别\n",
|
||
" detection_classes = detections[:, 5].int() # 预测类别\n",
|
||
" iou = box_iou(labels[:, 1:], detections[:, :4]) # 计算所有结果的IoU\n",
|
||
"\n",
|
||
" x = torch.where(iou > self.iou_thres) # 根据IoU匹配结果,返回满足条件的索引 x(dim0), (dim1)\n",
|
||
" if x[0].shape[0]: # x[0]:存在为True的索引(gt索引), x[1]当前所有下True的索引(dt索引)\n",
|
||
" # shape:[n, 3] 3->[label, detect, iou]\n",
|
||
" matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()\n",
|
||
" if x[0].shape[0] > 1:\n",
|
||
" matches = matches[matches[:, 2].argsort()[::-1]] # 根据IoU从大到小排序\n",
|
||
" matches = matches[np.unique(matches[:, 1], return_index=True)[1]] # 若一个dt匹配多个gt,保留IoU最高的gt匹配结果\n",
|
||
" matches = matches[matches[:, 2].argsort()[::-1]] # 根据IoU从大到小排序\n",
|
||
" matches = matches[np.unique(matches[:, 0], return_index=True)[1]] # 若一个gt匹配多个dt,保留IoU最高的dt匹配结果\n",
|
||
" else:\n",
|
||
" matches = np.zeros((0, 3))\n",
|
||
"\n",
|
||
" n = matches.shape[0] > 0 # 是否存在和gt匹配成功的dt\n",
|
||
" m0, m1, _ = matches.transpose().astype(int) # m0:gt索引 m1:dt索引\n",
|
||
" for i, gc in enumerate(gt_classes): # 实际的结果\n",
|
||
" j = m0 == i # 预测为该目标的预测结果序号\n",
|
||
" if n and sum(j) == 1: # 该实际结果预测成功\n",
|
||
" self.matrix[detection_classes[m1[j]], gc] += 1 # 预测为目标,且实际为目标\n",
|
||
" else: # 该实际结果预测失败\n",
|
||
" self.matrix[self.nc, gc] += 1 # 预测为背景,但实际为目标\n",
|
||
"\n",
|
||
" if n:\n",
|
||
" for i, dc in enumerate(detection_classes): # 对预测结果处理\n",
|
||
" if not any(m1 == i): # 若该预测结果没有和实际结果匹配\n",
|
||
" self.matrix[dc, self.nc] += 1 # 预测为目标,但实际为背景\n",
|
||
"\n",
|
||
" def tp_fp(self):\n",
|
||
" tp = self.matrix.diagonal() # true positives\n",
|
||
" fp = self.matrix.sum(1) - tp # false positives\n",
|
||
" # fn = self.matrix.sum(0) - tp # false negatives (missed detections)\n",
|
||
" return tp[:-1], fp[:-1] # remove background class\n",
|
||
"\n",
|
||
" def print(self):\n",
|
||
" for i in range(self.nc + 1):\n",
|
||
" print(' '.join(map(str, self.matrix[i])))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"id": "c1a01351-da0b-4cf7-92fd-094966af2907",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"json_path = '/home/andrewtal/Workspace/metrials/results/our/10.json'"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"id": "60c50e35-7f04-4449-8d70-929efc80d66b",
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"ename": "NameError",
|
||
"evalue": "name 'load_data' is not defined",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||
"Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m points, edge_index, labels, _ \u001b[38;5;241m=\u001b[39m \u001b[43mload_data\u001b[49m(json_path)\n\u001b[1;32m 3\u001b[0m mask_pd \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mzeros((\u001b[38;5;241m2048\u001b[39m, \u001b[38;5;241m2048\u001b[39m))\n\u001b[1;32m 4\u001b[0m mask_pd[points[:, \u001b[38;5;241m0\u001b[39m], points[:, \u001b[38;5;241m1\u001b[39m]] \u001b[38;5;241m=\u001b[39m labels \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
|
||
"\u001b[0;31mNameError\u001b[0m: name 'load_data' is not defined"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"points, edge_index, labels, _ = load_data(json_path)\n",
|
||
"\n",
|
||
"mask_pd = np.zeros((2048, 2048))\n",
|
||
"mask_pd[points[:, 0], points[:, 1]] = labels + 1\n",
|
||
"mask_pd = np.array(mask_pd, np.uint8)\n",
|
||
"\n",
|
||
"mask_gt = np.array(Image.open(json_path.replace('.json', '.png')), np.uint8)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "da527085-7cac-4cf1-92eb-b1c3b2617566",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "bb76e1d6-9844-4e23-84cb-f6516d058f47",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "a694e738-4cc0-4e39-8b3d-5099e73b09f4",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "9a1d5542-81cf-45b1-95bb-e3a16c74586a",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"id": "31da26eb-3abc-42cd-94ca-7c1d43025227",
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "cmae",
|
||
"language": "python",
|
||
"name": "cmae"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.8.16"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|