# atom-predict/msunet/predict.py
import argparse
import os
import zipfile
import glob
from typing import List
from pre import process_slide
from test_pl_unequal import predict
from typing import Dict
from typing import Tuple
from utils.labelme import save_pred_to_json
from plot import plot2
import os
import math
from PIL import Image
import numpy as np
from dp.launching.report import Report, ReportSection, AutoReportElement
import os
from typing import Dict, List
from PIL import Image
from collections import defaultdict
# # 增加参数一个是数据集的路径,另外一个是保存的路径
# parser = argparse.ArgumentParser(description='Process some integers.')
# parser.add_argument('--data_path', type=str, default='./data/test', help='path to dataset')
# parser.add_argument('--save_path', type=str, default='./data/test_processed', help='path to save processed data')
# args = parser.parse_args()
def zip_file(data_path: str, save_path: str) -> List[str]:
    """Extract a zip archive and collect the image files it contained.

    Args:
        data_path: Path to the ``.zip`` archive to extract.
        save_path: Directory to extract into; created if it does not exist.

    Returns:
        Absolute-or-relative paths (as produced by ``glob``) of every
        ``.jpg`` file followed by every ``.png`` file found recursively
        under ``save_path``.
    """
    # Ensure the extraction target exists.
    os.makedirs(save_path, exist_ok=True)
    # Unpack the archive in place.
    with zipfile.ZipFile(data_path, 'r') as archive:
        archive.extractall(save_path)
    # Gather images: jpg results first, then png, matching the caller's
    # expectation of that ordering.
    found: List[str] = []
    for pattern in ('*.jpg', '*.png'):
        found.extend(glob.glob(os.path.join(save_path, '**', pattern), recursive=True))
    return found
def get_cmd_path() -> str:
    """Return the absolute path of the directory containing this script.

    Also echoes the script path and its directory to stdout for debugging.
    """
    script_path = os.path.abspath(__file__)
    script_dir = os.path.dirname(script_path)
    # Emit both locations for the operator's benefit.
    for message in (f"当前脚本的路径: {script_path}",
                    f"当前脚本所在的目录: {script_dir}"):
        print(message)
    return script_dir
def create_save_path(base_path: str) -> Dict[str, str]:
    """Create the working sub-directories used by the prediction pipeline.

    Args:
        base_path: Root directory under which all stage directories live.

    Returns:
        Mapping from pipeline-stage key (``dataset``, ``pre_process``,
        ``predict``, ``post_process``, ``plot_view``) to its directory path.
        Every directory is guaranteed to exist on return.
    """
    # Stage key -> leaf directory name.
    layout = {
        "dataset": "predict_dataset",
        "pre_process": "predict_pre_process",
        "predict": "predict_result",
        "post_process": "predict_post_process",
        "plot_view": "predict_result_view",
    }
    stage_paths = {stage: os.path.join(base_path, leaf) for stage, leaf in layout.items()}
    for directory in stage_paths.values():
        os.makedirs(directory, exist_ok=True)
    return stage_paths
def find_image_json_pairs(directory: str) -> List[Dict[str, str]]:
    """Find matching ``.jpg``/``.json`` file pairs in *directory*.

    Args:
        directory: Directory to scan (non-recursive).

    Returns:
        A list of ``{'jpg': ..., 'json': ...}`` dicts with absolute paths,
        one entry per base name that has BOTH files present, ordered by
        file name.

    Bug fix: the original two-branch scan could append the same pair twice
    when the ``.json`` file happened to precede its ``.jpg`` in
    ``os.listdir()`` order, and it rebuilt the jpg path as ``name + '.jpg'``
    which broke for uppercase ``.JPG`` names on case-sensitive filesystems.
    A pair requires both files, so a single pass over the jpg files is
    sufficient and cannot duplicate.
    """
    pairs: List[Dict[str, str]] = []
    # sorted() makes the result deterministic regardless of filesystem order.
    for filename in sorted(os.listdir(directory)):
        name, ext = os.path.splitext(filename)
        if ext.lower() != '.jpg':
            continue
        json_file = os.path.join(directory, name + '.json')
        if os.path.isfile(json_file):
            pairs.append({
                # Use the actual listing name so '.JPG' is preserved.
                'jpg': os.path.abspath(os.path.join(directory, filename)),
                'json': os.path.abspath(json_file),
            })
    return pairs
# if __name__ == '__main__':
# cmd_path = get_cmd_path()
# save_path = create_save_path(cmd_path)
# print("save_path:", save_path)
# image_paths = zip_file(args.data_path, save_path["dataset"])
# for path in image_paths:
# print("image path:", path)
#
# # 预处理
# print("开始预处理")
# for path in image_paths:
# process_slide(path, save_path["pre_process"])
# print("预处理结束")
#
# # 预测
# print("开始预测")
#
# files = np.array(glob.glob(save_path["pre_process_img"] + "/*.png")).tolist()
# filename = predict("", files, save_path["predict"])
# print("预测结束")
#
# # 后处理
# print("开始后处理")
# save_pred_to_json(filename, save_path["post_process"])
# print("后处理结束")
#
# # 可视化
# print("开始可视化")
# pairs = find_image_json_pairs(save_path["post_process"])
# print("pairs:", pairs)
# for pair in pairs:
# plot2(pair['jpg'], pair['json'], save_path["plot_view"])
# print("可视化结束")
#
# class Options(BaseModel):
# data_path: InputFilePath = Field(..., ftypes=['.zip'], description="测试的图片")
# model_path: Optional[InputFilePath] = Field(ftypes=['.ckpt'], description="使用的模型")
# output_dir: OutputDirectory = Field(
# default="./result"
# ) # default will be override after online
#
#
# class GlobalOptions(Options, BaseModel):
# ...
def resize_to_nearest_multiple_of_32(image_path: str, output_dir: str) -> str:
    """Resize an image so both dimensions round UP to a multiple of 32.

    The resized copy is written to *output_dir* under the original file
    name, and that path is returned. Aspect ratio is NOT preserved — each
    side is stretched independently to the next multiple of 32 (the model
    presumably requires 32-divisible inputs; verify against the network).

    Args:
        image_path: Path of the source image.
        output_dir: Directory to write the resized copy into (created if
            missing).

    Returns:
        Path of the saved, resized image.
    """
    # Fix: use a context manager so the PIL file handle is closed instead
    # of leaking (the original never closed the opened image).
    with Image.open(image_path) as image:
        width, height = image.size
        new_width = math.ceil(width / 32) * 32
        new_height = math.ceil(height / 32) * 32
        resized_image = image.resize((new_width, new_height))
    # Report the size change for the operator.
    print(f"Original size: {width}x{height}")
    print(f"Resized size: {new_width}x{new_height}")
    # Save under the original file name inside the output directory.
    file_name = os.path.basename(image_path)
    output_path = os.path.join(output_dir, file_name)
    os.makedirs(output_dir, exist_ok=True)  # robustness: tolerate missing dir
    resized_image.save(output_path)
    return output_path
def group_images_by_size(output_dir: str) -> Dict[str, List[str]]:
    """Group the images in a directory by their pixel dimensions.

    Args:
        output_dir: Directory containing image files (non-recursive scan;
            only names ending in png/jpg/jpeg are considered).

    Returns:
        Mapping from a ``"{width}x{height}"`` key to the list of file
        paths that have that size.
    """
    size_to_images: Dict[str, List[str]] = defaultdict(list)
    for filename in os.listdir(output_dir):
        if not filename.lower().endswith(('png', 'jpg', 'jpeg')):
            continue
        file_path = os.path.join(output_dir, filename)
        # Fix: close each PIL handle promptly via the context manager —
        # the original opened every image and never closed it, leaking a
        # file handle per image.
        with Image.open(file_path) as image:
            width, height = image.size
        size_to_images[f"{width}x{height}"].append(file_path)
    # Return a plain dict so callers don't see defaultdict auto-insertion.
    return dict(size_to_images)
def predict_and_plot(model_path: str, img_paths: str, output_dir: str, is_after_train: bool = False) -> Dict[str, str]:
    """Run the full prediction pipeline: unpack/collect images, resize,
    predict per size group, convert predictions to labelme JSON, and plot.

    Args:
        model_path: Checkpoint path forwarded to ``predict``.
        img_paths: Either a ``.zip`` of images (default) or, when
            ``is_after_train`` is True, a directory searched recursively
            for ``.jpg`` files.
        output_dir: Root under which all stage directories are created.
        is_after_train: Selects the ``img_paths`` interpretation above.

    Returns:
        The stage-name -> directory mapping from ``create_save_path``.
    """
    print("data path:", img_paths)
    #print("model_path:", opts.model_path.get_full_path())
    print("output_dir:", output_dir)
    save_path = create_save_path(output_dir)
    print("save_path:", save_path)
    if not is_after_train:
        # Zip input: extract and collect jpg/png paths.
        image_paths = zip_file(img_paths, save_path["dataset"])
    else:
        # Post-training mode: images already on disk as jpg files.
        image_paths = glob.glob(img_paths + "/**/*.jpg", recursive=True)
    for path in image_paths:
        print("image path:", path)
    # Pre-processing: pad every image up to 32-divisible dimensions.
    print("开始预处理")
    for path in image_paths:
        resize_to_nearest_multiple_of_32(path, save_path["pre_process"])
    print("预处理结束")
    grouped_images = group_images_by_size(save_path["pre_process"])
    # Prediction: one batch per distinct image size, since the model is
    # presumably fed fixed-size batches — TODO confirm against `predict`.
    print("开始预测")
    print("预测使用的模型路径:", model_path)
    filenames = []
    batch = 0
    # NOTE(review): this loop variable shadows the earlier `image_paths`
    # list; harmless here because the original list is no longer needed.
    for size, image_paths in grouped_images.items():
        batch += 1
        print("开始预测图片的size:", size)
        filename = predict(model_path, image_paths, save_path["predict"], batch)
        filenames.append(filename)
    print("预测结束")
    # Post-processing: convert each prediction file to labelme JSON.
    print("开始后处理")
    for filename in filenames:
        save_pred_to_json(filename, save_path["post_process"])
    print("后处理结束")
    # Visualization: overlay each JSON onto its matching jpg.
    print("开始可视化")
    pairs = find_image_json_pairs(save_path["post_process"])
    print("pairs:", pairs)
    for pair in pairs:
        plot2(pair['jpg'], pair['json'], save_path["plot_view"])
    print("可视化结束")
    # Report generation is currently disabled.
    #generate_report(save_path, output_dir)
    return save_path
def remove_prefix(full_path: str, prefix: str) -> str:
    """Return *full_path* expressed relative to the *prefix* directory."""
    return os.path.relpath(full_path, prefix)
def generate_report(save_path: Dict[str, str], output_dir: str) -> None:
    """Build and save a launching report with prediction and source images.

    Args:
        save_path: Stage-name -> directory mapping from ``create_save_path``;
            reads ``dataset`` (original jpgs, recursive) and ``plot_view``
            (rendered prediction jpgs, flat).
        output_dir: Report root; image paths are stored relative to it and
            the report is saved here.
    """
    # Original images: every jpg under the dataset directory.
    ori_elements = [
        AutoReportElement(
            # Fix: os.path.basename instead of split("/")[-1], which broke
            # on Windows path separators.
            path=remove_prefix(img_path, output_dir),
            title=os.path.basename(img_path),
            description='原始图片',
        )
        for img_path in glob.glob(save_path["dataset"] + "/**/*.jpg", recursive=True)
    ]
    ori_img_section = ReportSection(title="原始图片", ncols=2, elements=ori_elements)
    # Prediction overlays rendered by plot2.
    pred_elements = [
        AutoReportElement(
            path=remove_prefix(img_path, output_dir),
            title=os.path.basename(img_path),
            description='预测结果',
        )
        for img_path in glob.glob(save_path["plot_view"] + "/*.jpg")
    ]
    post_process_img_section = ReportSection(title="预测结果", ncols=3, elements=pred_elements)
    # Predictions are shown first, then the originals.
    report = Report(title="原子定位结果", sections=[post_process_img_section, ori_img_section])
    report.save(output_dir)
# def to_parser():
# return to_runner(
# GlobalOptions,
# main,
# version='0.1.0',
# exception_handler=default_minimal_exception_handler,
# )
#
#
# if __name__ == '__main__':
# to_parser()(sys.argv[1:])
#