import os
import cv2
import glob
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
# from core.data import load_data_plot
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, confusion_matrix


def load_data(json_path):
    """Read point annotations and class labels from an annotation JSON file.

    Each entry of ``json_data['shapes']`` contributes its first point, which is
    stored in the file as ``[x, y]`` and returned here reversed as ``[row, col]``.

    Args:
        json_path: Path to the annotation JSON file.

    Returns:
        A ``(points, labels)`` tuple. ``points`` is an ``(N, 2)`` int32 array of
        ``(row, col)`` coordinates (fractional coordinates are truncated by the
        int32 cast); it is ``(0, 2)`` when the file has no shapes. ``labels`` is
        an ``(N,)`` uint8 array of class ids. If any shape has a label other
        than ``'0'``, ``'1'`` or ``'2'`` (or has no label at all), every label
        falls back to 0.
    """
    data_dict = {
        '0': 0,
        '1': 1,
        '2': 2,
    }
    # JSON is UTF-8 by specification; don't depend on the platform default encoding.
    with open(json_path, encoding='utf-8') as f:
        json_data = json.load(f)
    shapes = json_data['shapes']

    # [::-1] turns the stored (x, y) into (row, col) for array indexing.
    # reshape keeps the result 2-D even when there are no shapes, so callers
    # can always index points[:, 0] / points[:, 1].
    points = np.array([item['points'][0][::-1] for item in shapes], np.int32).reshape(-1, 2)
    try:
        labels = np.array([data_dict[item['label']] for item in shapes], np.uint8)
    except KeyError:
        # Unknown or missing label: treat the whole file as unlabelled (class 0).
        labels = np.zeros(len(shapes), np.uint8)
    return points, labels


def plot2(img_path, json_path, save_path=None, psize=20):
    """Overlay annotated points on a grayscale image and save the figure.

    Points of class 0, 1 and 2 are drawn in red, green and yellow respectively.
    The figure is written to ``<save_path>/predict_<image file name>``.

    Args:
        img_path: Path to the image to draw on (read as grayscale).
        json_path: Path to the annotation JSON, parsed by ``load_data``.
        save_path: Output directory; created if it does not exist. When None
            (the default), the figure is saved in the image's own directory.
        psize: Marker size passed to ``plt.scatter``.

    Returns:
        The path of the saved figure.

    Raises:
        FileNotFoundError: If ``img_path`` cannot be read as an image.
    """
    colors = ['red', 'green', 'yellow']
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {img_path}")
    points, labels = load_data(json_path)

    # Rasterise the points into a class mask (0 = background, k + 1 = class k),
    # then recover per-class coordinates from it; duplicate points collapse.
    mask_pd = np.zeros(img.shape, np.uint8)
    mask_pd[points[:, 0], points[:, 1]] = labels + 1

    if save_path is None:
        save_path = os.path.dirname(img_path)
    # Make sure the output directory exists ('' means the current directory).
    if save_path:
        os.makedirs(save_path, exist_ok=True)
    # Name the output after the original image file.
    result_path = os.path.join(save_path, "predict_" + os.path.basename(img_path))

    fig = plt.figure(figsize=(9, 9))
    try:
        plt.imshow(img, cmap='gray')
        for cls, color in enumerate(colors):
            rows, cols = np.where(mask_pd == cls + 1)
            plt.scatter(cols, rows, s=psize, c=color)
        plt.axis('off')
        plt.savefig(result_path)
    finally:
        # Always release the figure so batch runs don't accumulate open figures.
        plt.close(fig)
    return result_path


# Example:
# json_path = './result/post_process/10.json'
# img_path = './result/post_process/10.jpg'
# plot2(img_path, json_path)