# Inference pipeline: unzip input images, preprocess (pad to multiples of 32),
# predict per size group, post-process predictions to JSON, and plot results.
import argparse
|
||
import os
|
||
import zipfile
|
||
import glob
|
||
from typing import List
|
||
from pre import process_slide
|
||
from test_pl_unequal import predict
|
||
from typing import Dict
|
||
from typing import Tuple
|
||
from utils.labelme import save_pred_to_json
|
||
from plot import plot2
|
||
import os
|
||
import math
|
||
from PIL import Image
|
||
import numpy as np
|
||
from dp.launching.report import Report, ReportSection, AutoReportElement
|
||
|
||
import os
|
||
from typing import Dict, List
|
||
from PIL import Image
|
||
from collections import defaultdict
|
||
|
||
|
||
|
||
# # 增加参数一个是数据集的路径,另外一个是保存的路径
|
||
# parser = argparse.ArgumentParser(description='Process some integers.')
|
||
# parser.add_argument('--data_path', type=str, default='./data/test', help='path to dataset')
|
||
# parser.add_argument('--save_path', type=str, default='./data/test_processed', help='path to save processed data')
|
||
# args = parser.parse_args()
|
||
|
||
def zip_file(data_path: str, save_path: str) -> List[str]:
    """Extract a zip archive and collect the image files it contained.

    Args:
        data_path: Path to the ``.zip`` archive to extract.
        save_path: Directory the archive is extracted into (created if missing).

    Returns:
        Paths of all ``.jpg`` and ``.png`` files found recursively under
        ``save_path``, sorted for deterministic processing order.
    """
    # exist_ok avoids the check-then-create race of the original code.
    os.makedirs(save_path, exist_ok=True)

    # Extract the archive into the save directory.
    with zipfile.ZipFile(data_path, 'r') as zip_ref:
        zip_ref.extractall(save_path)

    # Collect jpg and png files anywhere under the extraction root.
    image_paths: List[str] = []
    for pattern in ('*.jpg', '*.png'):
        image_paths.extend(
            glob.glob(os.path.join(save_path, '**', pattern), recursive=True))

    # glob order is filesystem-dependent; sort so runs are reproducible.
    return sorted(image_paths)
|
||
|
||
def get_cmd_path() -> str:
    """Return the absolute directory that contains the current script."""
    here = os.path.abspath(__file__)
    folder = os.path.dirname(here)

    print(f"当前脚本的路径: {here}")
    print(f"当前脚本所在的目录: {folder}")

    return folder
|
||
|
||
def create_save_path(base_path: str) -> Dict[str, str]:
    """Create (if needed) the pipeline's output directory layout.

    Args:
        base_path: Root directory under which all stage folders are placed.

    Returns:
        Mapping from stage name to the directory path for that stage.
    """
    layout = (
        ("dataset", "predict_dataset"),
        ("pre_process", "predict_pre_process"),
        ("predict", "predict_result"),
        ("post_process", "predict_post_process"),
        ("plot_view", "predict_result_view"),
    )
    paths: Dict[str, str] = {
        stage: os.path.join(base_path, subdir) for stage, subdir in layout
    }

    # Make sure every stage directory exists on disk.
    for directory in paths.values():
        os.makedirs(directory, exist_ok=True)

    return paths
|
||
|
||
|
||
|
||
def find_image_json_pairs(directory: str) -> List[Dict[str, str]]:
    """Find pairs of JPG and JSON files in the specified directory.

    Bug fix: the original implementation scanned both the ``.jpg`` and the
    ``.json`` side, but only marked pairs discovered via the ``.jpg`` branch —
    so when ``os.listdir`` happened to yield ``x.json`` before ``x.jpg`` the
    same pair was appended twice.  Iterating only the JPG files removes the
    duplication (every pair the JSON branch could find is also found here).
    It also used ``name + '.jpg'`` for the image path, which pointed at a
    non-existent file for uppercase ``.JPG`` names; the actual filename is
    used instead.

    Args:
        directory: The directory path where files are located.

    Returns:
        A list of dictionaries, each holding a matched pair as
        ``{'jpg': '/abs/path/x.jpg', 'json': '/abs/path/x.json'}``.
    """
    pairs: List[Dict[str, str]] = []

    # Sorted for a deterministic, filesystem-independent result order.
    for filename in sorted(os.listdir(directory)):
        name, ext = os.path.splitext(filename)
        if ext.lower() != '.jpg':
            continue

        json_file = os.path.join(directory, name + '.json')
        if os.path.isfile(json_file):
            pairs.append({
                'jpg': os.path.abspath(os.path.join(directory, filename)),
                'json': os.path.abspath(json_file),
            })

    return pairs
|
||
|
||
|
||
# if __name__ == '__main__':
|
||
# cmd_path = get_cmd_path()
|
||
# save_path = create_save_path(cmd_path)
|
||
# print("save_path:", save_path)
|
||
# image_paths = zip_file(args.data_path, save_path["dataset"])
|
||
# for path in image_paths:
|
||
# print("image path:", path)
|
||
#
|
||
# # 预处理
|
||
# print("开始预处理")
|
||
# for path in image_paths:
|
||
# process_slide(path, save_path["pre_process"])
|
||
# print("预处理结束")
|
||
#
|
||
# # 预测
|
||
# print("开始预测")
|
||
#
|
||
# files = np.array(glob.glob(save_path["pre_process_img"] + "/*.png")).tolist()
|
||
# filename = predict("", files, save_path["predict"])
|
||
# print("预测结束")
|
||
#
|
||
# # 后处理
|
||
# print("开始后处理")
|
||
# save_pred_to_json(filename, save_path["post_process"])
|
||
# print("后处理结束")
|
||
#
|
||
# # 可视化
|
||
# print("开始可视化")
|
||
# pairs = find_image_json_pairs(save_path["post_process"])
|
||
# print("pairs:", pairs)
|
||
# for pair in pairs:
|
||
# plot2(pair['jpg'], pair['json'], save_path["plot_view"])
|
||
# print("可视化结束")
|
||
#
|
||
|
||
|
||
|
||
# class Options(BaseModel):
|
||
# data_path: InputFilePath = Field(..., ftypes=['.zip'], description="测试的图片")
|
||
# model_path: Optional[InputFilePath] = Field(ftypes=['.ckpt'], description="使用的模型")
|
||
# output_dir: OutputDirectory = Field(
|
||
# default="./result"
|
||
# ) # default will be override after online
|
||
#
|
||
#
|
||
# class GlobalOptions(Options, BaseModel):
|
||
# ...
|
||
|
||
def resize_to_nearest_multiple_of_32(image_path, output_dir):
    """Resize an image so both dimensions round UP to a multiple of 32.

    Args:
        image_path: Path of the source image.
        output_dir: Directory the resized copy is written to, keeping the
            original filename.

    Returns:
        Path of the resized image inside ``output_dir``.
    """
    # Context manager releases the underlying file handle; the original
    # left every opened image (and its file descriptor) unclosed.
    with Image.open(image_path) as image:
        width, height = image.size
        new_width = math.ceil(width / 32) * 32
        new_height = math.ceil(height / 32) * 32
        resized_image = image.resize((new_width, new_height))

    # Log the size change for debugging the preprocessing stage.
    print(f"Original size: {width}x{height}")
    print(f"Resized size: {new_width}x{new_height}")

    # Write the result next to its siblings under output_dir.
    file_name = os.path.basename(image_path)
    output_path = os.path.join(output_dir, file_name)
    resized_image.save(output_path)
    return output_path
|
||
|
||
|
||
|
||
|
||
|
||
def group_images_by_size(output_dir: str) -> Dict[str, List[str]]:
    """Group the images in a directory by their pixel dimensions.

    Args:
        output_dir: Directory containing the images.

    Returns:
        Mapping from a ``"{width}x{height}"`` string to the list of image
        paths that have that size.
    """
    size_to_images: Dict[str, List[str]] = defaultdict(list)

    # Sorted for deterministic group contents regardless of filesystem order.
    for filename in sorted(os.listdir(output_dir)):
        if not filename.lower().endswith(('png', 'jpg', 'jpeg')):
            continue
        file_path = os.path.join(output_dir, filename)
        # "with" closes the file handle; the original leaked one per image.
        with Image.open(file_path) as image:
            width, height = image.size
        size_to_images[f"{width}x{height}"].append(file_path)

    # Return a plain dict so callers don't get defaultdict auto-insertion.
    return dict(size_to_images)
|
||
|
||
|
||
|
||
|
||
def predict_and_plot(model_path: str, img_paths: str, output_dir: str, is_after_train: bool = False) -> Dict[str, str]:
    """Run the full inference pipeline end to end.

    Stages: collect images (from a zip upload or an existing directory),
    pad each image to multiples of 32, predict one batch per distinct image
    size, convert predictions to JSON, and render image/JSON overlays.

    Args:
        model_path: Checkpoint path handed to ``predict``.
        img_paths: Zip archive path, or (when ``is_after_train``) a directory
            already containing ``.jpg`` images.
        output_dir: Root directory for all pipeline outputs.
        is_after_train: When True, skip unzipping and glob ``img_paths``.

    Returns:
        The stage-name -> directory mapping from ``create_save_path``.
    """
    print("data path:", img_paths)
    print("output_dir:", output_dir)

    save_path = create_save_path(output_dir)
    print("save_path:", save_path)

    # Either reuse images already on disk or unpack the uploaded archive.
    if is_after_train:
        image_paths = glob.glob(img_paths + "/**/*.jpg", recursive=True)
    else:
        image_paths = zip_file(img_paths, save_path["dataset"])

    for img in image_paths:
        print("image path:", img)

    # Preprocess: pad every image up to a multiple of 32 on each side.
    print("开始预处理")
    for img in image_paths:
        resize_to_nearest_multiple_of_32(img, save_path["pre_process"])
    print("预处理结束")

    grouped_images = group_images_by_size(save_path["pre_process"])

    # Predict one batch per distinct image size.
    print("开始预测")
    print("预测使用的模型路径:", model_path)
    filenames = []
    for batch, (size, batch_paths) in enumerate(grouped_images.items(), start=1):
        print("开始预测图片的size:", size)
        filenames.append(predict(model_path, batch_paths, save_path["predict"], batch))
    print("预测结束")

    # Post-process every prediction file into labelme-style JSON.
    print("开始后处理")
    for filename in filenames:
        save_pred_to_json(filename, save_path["post_process"])
    print("后处理结束")

    # Visualize each matched image/JSON pair.
    print("开始可视化")
    pairs = find_image_json_pairs(save_path["post_process"])
    print("pairs:", pairs)
    for pair in pairs:
        plot2(pair['jpg'], pair['json'], save_path["plot_view"])
    print("可视化结束")

    return save_path
|
||
|
||
|
||
|
||
def remove_prefix(full_path: str, prefix: str) -> str:
    """Return *full_path* expressed relative to the *prefix* directory."""
    return os.path.relpath(full_path, prefix)
|
||
|
||
def generate_report(save_path: Dict[str, str], output_dir: str) -> None:
    """Assemble and save the result report: predictions first, then originals.

    Args:
        save_path: Stage-name -> directory mapping from ``create_save_path``.
        output_dir: Report root; element paths are stored relative to it.
    """
    # Section: the original input images found under the dataset directory.
    img_elements = [
        AutoReportElement(
            path=remove_prefix(img_path, output_dir),
            # os.path.basename is correct on every platform; the original
            # split("/") broke on Windows-style paths.
            title=os.path.basename(img_path),
            description='原始图片',
        )
        for img_path in glob.glob(save_path["dataset"] + "/**/*.jpg", recursive=True)
    ]
    ori_img_section = ReportSection(title="原始图片", ncols=2, elements=img_elements)

    # Section: the rendered prediction overlays.
    img_elements = [
        AutoReportElement(
            path=remove_prefix(img_path, output_dir),
            title=os.path.basename(img_path),
            description='预测结果',
        )
        for img_path in glob.glob(save_path["plot_view"] + "/*.jpg")
    ]
    post_process_img_section = ReportSection(title="预测结果", ncols=3, elements=img_elements)

    report = Report(title="原子定位结果", sections=[post_process_img_section, ori_img_section])
    report.save(output_dir)
|
||
|
||
|
||
# def to_parser():
|
||
# return to_runner(
|
||
# GlobalOptions,
|
||
# main,
|
||
# version='0.1.0',
|
||
# exception_handler=default_minimal_exception_handler,
|
||
# )
|
||
#
|
||
#
|
||
# if __name__ == '__main__':
|
||
# to_parser()(sys.argv[1:])
|
||
#
|