211 lines
8.0 KiB
Python
211 lines
8.0 KiB
Python
import argparse
|
||
import json
|
||
import os
|
||
import zipfile
|
||
import glob
|
||
from typing import List, Dict
|
||
from pre import process_slide # Assuming this is a custom module
|
||
from test_pl_unequal import predict # Assuming this is a custom module
|
||
from utils.labelme import save_pred_to_json # Assuming this is a custom module
|
||
from plot import plot2 # Assuming this is a custom module
|
||
|
||
import numpy as np
|
||
from dp.launching.report import Report, ReportSection, AutoReportElement # Assuming this is a custom module
|
||
from pathlib import Path
|
||
import streamlit as st
|
||
|
||
|
||
def zip_file(data_path: str, save_path: str) -> List[str]:
    """Extract a zip archive and return the image files found below the target.

    Args:
        data_path: Path of the ``.zip`` archive to read.
        save_path: Directory to extract into; created when it is missing.

    Returns:
        Every ``.jpg`` path followed by every ``.png`` path found anywhere
        below ``save_path``. Extensions are matched case-insensitively and
        each group is sorted so the result is reproducible.

    Raises:
        zipfile.BadZipFile: If ``data_path`` is not a valid zip archive.
    """
    # exist_ok avoids the race between a separate exists() check and makedirs().
    os.makedirs(save_path, exist_ok=True)
    with zipfile.ZipFile(data_path, 'r') as zip_ref:
        zip_ref.extractall(save_path)

    # A single recursive glob filtered by lower-cased suffix also picks up
    # names such as "IMG.JPG" on case-sensitive filesystems. glob keeps
    # skipping dot-prefixed entries, just like the per-extension patterns did.
    candidates = sorted(glob.glob(os.path.join(save_path, '**', '*'), recursive=True))
    image_paths: List[str] = []
    for suffix in ('.jpg', '.png'):
        image_paths.extend(
            path for path in candidates
            if os.path.splitext(path)[1].lower() == suffix
        )
    return image_paths
|
||
|
||
|
||
def get_cmd_path() -> str:
    """Return the absolute path of the directory containing this script."""
    return os.path.dirname(os.path.abspath(__file__))
|
||
|
||
|
||
def create_save_path(base_path: str) -> Dict[str, str]:
    """Create the output sub-directories under ``base_path``.

    Args:
        base_path: Directory under which the sub-directories are created.

    Returns:
        Mapping from the logical names ``dataset``, ``predict``,
        ``post_process`` and ``plot_view`` to the corresponding directory
        paths. Every directory exists when the function returns.
    """
    layout = (
        ("dataset", "predict_dataset"),
        ("predict", "predict_result"),
        ("post_process", "predict_post_process"),
        ("plot_view", "predict_result_view"),
    )
    paths = {key: os.path.join(base_path, folder) for key, folder in layout}
    for target in paths.values():
        os.makedirs(target, exist_ok=True)
    return paths
|
||
|
||
|
||
def find_image_json_pairs(directory: str) -> List[Dict[str, str]]:
    """Find ``.jpg`` files that have a ``.json`` file with the same stem.

    Only the top level of ``directory`` is inspected. Extensions are compared
    case-insensitively, and the returned paths use the real file names on
    disk instead of being rebuilt with a lowercase extension.

    Args:
        directory: Directory to scan.

    Returns:
        One ``{'jpg': <abs path>, 'json': <abs path>}`` dict per image that
        has a matching JSON file, ordered by image file name. Each image
        appears at most once, whatever order the OS lists the directory in.
    """
    filenames = sorted(os.listdir(directory))

    # Index the JSON files by stem so the images alone drive the pairing.
    # Driving from both sides is what produced duplicate pairs whenever the
    # JSON file happened to be listed before its image.
    json_by_stem: Dict[str, str] = {}
    for filename in filenames:
        stem, ext = os.path.splitext(filename)
        if ext.lower() == '.json':
            json_by_stem.setdefault(stem, filename)

    pairs = []
    for filename in filenames:
        stem, ext = os.path.splitext(filename)
        if ext.lower() != '.jpg' or stem not in json_by_stem:
            continue
        jpg_file = os.path.join(directory, filename)
        json_file = os.path.join(directory, json_by_stem[stem])
        # Skip directories that merely carry a matching name.
        if os.path.isfile(jpg_file) and os.path.isfile(json_file):
            pairs.append({
                'jpg': os.path.abspath(jpg_file),
                'json': os.path.abspath(json_file),
            })
    return pairs
|
||
|
||
|
||
def predict_and_plot(model_path: str, img_paths: str, output_dir: str, progress: st.progress,
                     is_after_train: bool = False) -> None:
    """Run prediction on a set of images, plot the results and write a report.

    Args:
        model_path: Passed unchanged to ``predict`` as its first argument.
        img_paths: Path of a zip archive of images, or, when
            ``is_after_train`` is true, a directory that is searched
            recursively for ``.jpg`` files.
        output_dir: Directory under which all result folders are created.
        progress: Object whose ``progress(percent)`` method is called with
            20, 60, 80 and 100 as the stages complete.
        is_after_train: Selects how ``img_paths`` is interpreted (see above).
    """
    save_path = create_save_path(output_dir)

    # Stage 1: collect the input images.
    if is_after_train:
        image_paths = glob.glob(img_paths + "/**/*.jpg", recursive=True)
    else:
        image_paths = zip_file(img_paths, save_path["dataset"])
    progress.progress(20)

    # NOTE: per-image pre-processing via process_slide is currently disabled.

    # Stage 2: predict, then convert the prediction output to JSON files.
    prediction_output = predict(model_path, image_paths, save_path["predict"])
    save_pred_to_json(prediction_output, save_path["post_process"])
    progress.progress(60)

    # Stage 3: draw every image that has a matching JSON file.
    for match in find_image_json_pairs(save_path["post_process"]):
        plot2(match['jpg'], match['json'], save_path["plot_view"])
    progress.progress(80)

    # Stage 4: assemble the report.
    generate_report(save_path, output_dir)
    progress.progress(100)
|
||
|
||
|
||
def remove_prefix(full_path: str, prefix: str) -> str:
    """Return ``full_path`` expressed relative to the directory ``prefix``."""
    relative = os.path.relpath(full_path, start=prefix)
    return relative
|
||
|
||
|
||
def _build_image_elements(pattern: str, output_dir: str, description: str,
                          recursive: bool = False) -> list:
    """Create one report element for each image matching the glob ``pattern``."""
    return [
        AutoReportElement(
            path=remove_prefix(img_path, output_dir),
            # basename handles whichever path separator the platform uses,
            # unlike splitting on "/".
            title=os.path.basename(img_path),
            description=description,
        )
        # Sorted so the report lists images in a reproducible order.
        for img_path in sorted(glob.glob(pattern, recursive=recursive))
    ]


def generate_report(save_path: Dict[str, str], output_dir: str) -> None:
    """Build the report of original and predicted images and save it.

    Args:
        save_path: Mapping with at least the keys ``dataset`` (searched
            recursively for ``.jpg`` originals) and ``plot_view`` (searched
            at its top level for ``.jpg`` result plots).
        output_dir: Directory the report is saved to. Image paths inside the
            report are stored relative to it.
    """
    ori_img_section = ReportSection(
        title="原始图片",
        ncols=2,
        elements=_build_image_elements(
            save_path["dataset"] + "/**/*.jpg", output_dir, '原始图片', recursive=True
        ),
    )
    post_process_img_section = ReportSection(
        title="预测结果",
        ncols=3,
        elements=_build_image_elements(
            save_path["plot_view"] + "/*.jpg", output_dir, '预测结果'
        ),
    )
    # The prediction section is listed first, ahead of the originals.
    report = Report(title="原子定位结果", sections=[post_process_img_section, ori_img_section])
    report.save(output_dir)
|
||
|
||
|
||
# Streamlit App
st.title("原子位置缺陷预测")
st.write("上传一个包含图片的zip文件进行预测。")

# Upload widget for the input archive; only .zip files are accepted.
# NOTE(review): the earlier comment mentioned a 500 MB upload limit, but
# nothing here sets one - confirm it is configured elsewhere.
# A batch-size input and a model-file (.ckpt) uploader used to be offered
# here; both are disabled, so the bundled checkpoint below is always used.
data_path = st.file_uploader("上传数据文件 (.zip)", type=["zip"])

# The checkpoint and the result folder both live next to this script.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(_APP_DIR, "model/last.ckpt")
output_dir = os.path.join(_APP_DIR, "result")
|
||
|
||
|
||
def list_files_and_dirs_in_directory(directory: str) -> Dict[str, List[str]]:
    """Map every folder below ``directory`` to the file names it contains.

    Args:
        directory: Root of the tree to walk.

    Returns:
        Dict keyed by folder path relative to ``directory`` (the root itself
        is ``"."``), whose values are the file names directly inside that
        folder.
    """
    return {
        os.path.relpath(current, directory): filenames
        for current, _subdirs, filenames in os.walk(directory)
    }
|
||
|
||
def display_directory_structure(base_dir: str, file_dict: Dict[str, List[str]]):
    """Render a folder listing with one button per file.

    Args:
        base_dir: Directory that the keys of ``file_dict`` are relative to.
        file_dict: Mapping of relative folder path (``"."`` for the root) to
            the file names inside that folder.

    Returns:
        The joined path of the first file whose button reports a click, or
        ``None`` when no button was clicked. Rendering stops at that file.
    """
    st.title("目录结构")
    for folder, filenames in file_dict.items():
        st.subheader("根目录" if folder == "." else folder)
        for filename in filenames:
            candidate = os.path.join(base_dir, folder, filename)
            # The full path doubles as the widget key for this button.
            clicked = st.button(filename, key=candidate)
            if clicked:
                return candidate
    return None
|
||
|
||
# def display_directory_structure(base_dir: str, file_dict: Dict[str, List[str]]):
|
||
# st.sidebar.title("目录结构")
|
||
# for dir_path, files in file_dict.items():
|
||
# if dir_path == ".":
|
||
# st.sidebar.subheader("根目录")
|
||
# else:
|
||
# st.sidebar.subheader(dir_path)
|
||
# for file in files:
|
||
# full_path = os.path.join(base_dir, dir_path, file)
|
||
# if st.sidebar.button(file, key=full_path):
|
||
# return full_path
|
||
# return None
|
||
#
|
||
|
||
def list_files(root_dir):
    """Show the images of the known result folders found below ``root_dir``.

    Folders named ``predict_result_view`` and ``train_original`` each get a
    header, followed by every ``.jpg`` / ``.png`` file directly inside them.
    All other folders are ignored.
    """
    headers = {
        'predict_result_view': '预测结果图片',
        'train_original': '原始图片',
    }
    for current_dir, _subdirs, filenames in os.walk(root_dir):
        heading = headers.get(os.path.basename(current_dir))
        if heading is None:
            continue
        st.header(heading)
        for filename in filenames:
            if filename.endswith(('.jpg', '.png')):
                st.image(os.path.join(current_dir, filename), caption=filename)
|
||
|
||
# Run the pipeline once a file has been uploaded and the button is pressed.
if data_path and st.button("运行预测"):
    import tempfile

    # Write the upload into a private temporary directory instead of a fixed
    # "temp.zip" in the working directory: concurrent sessions no longer
    # overwrite each other's archive, and the file is removed afterwards.
    with tempfile.TemporaryDirectory() as upload_dir:
        archive_path = os.path.join(upload_dir, "upload.zip")
        with open(archive_path, "wb") as f:
            f.write(data_path.getbuffer())

        progress = st.progress(0)
        with st.spinner("预测处理中,请稍候..."):
            # The archive is extracted below output_dir during this call, so
            # it is no longer needed once the call returns.
            predict_and_plot(model_path, archive_path, output_dir, progress)

    st.success("预测完成。以下是结果文件预览:")

    list_files(output_dir)
|
||
# # 展示结果文件和图片
|
||
# files_and_dirs = list_files_and_dirs_in_directory(output_dir)
|
||
# selected_file = display_directory_structure(output_dir, files_and_dirs)
|
||
#
|
||
# if selected_file:
|
||
# if selected_file.lower().endswith(('.png', '.jpg', '.jpeg')):
|
||
# st.image(selected_file, caption=selected_file, use_column_width=True)
|
||
# else:
|
||
# with open(selected_file, 'r', encoding='utf-8') as f:
|
||
# st.text(f"内容:{selected_file}")
|
||
# st.text(f.read())
|
||
|
||
|