atom-predict/msunet/predict_streamlit.py

211 lines
8.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import glob
import json
import os
import tempfile
import zipfile
from pathlib import Path
from typing import List, Dict

import numpy as np
import streamlit as st
from dp.launching.report import Report, ReportSection, AutoReportElement # Assuming this is a custom module

from pre import process_slide # Assuming this is a custom module
from test_pl_unequal import predict # Assuming this is a custom module
from utils.labelme import save_pred_to_json # Assuming this is a custom module
from plot import plot2 # Assuming this is a custom module
def zip_file(data_path: str, save_path: str, extensions: tuple = ("jpg", "png")) -> List[str]:
    """Extract the zip archive ``data_path`` into ``save_path`` and list its images.

    Despite its name, this function *unzips*: every member of the archive is
    extracted under ``save_path``, which is created (with parents) if missing.

    Args:
        data_path: path of the zip archive to extract.
        save_path: destination directory.
        extensions: file extensions (without the dot) to collect, matched by
            ``glob`` (case-sensitive on POSIX). Defaults to ``("jpg", "png")``.

    Returns:
        Paths of matching files found anywhere under ``save_path`` after
        extraction (including files left there by earlier extractions),
        grouped by extension in the order of ``extensions`` and sorted within
        each group so the result is deterministic.
    """
    # exist_ok replaces the racy "if not exists: makedirs" check-then-create.
    os.makedirs(save_path, exist_ok=True)
    with zipfile.ZipFile(data_path, 'r') as zip_ref:
        # zipfile strips absolute prefixes and ".." components from member
        # names during extraction. NOTE(review): the archive is a user upload,
        # and its total extracted size is not bounded here.
        zip_ref.extractall(save_path)
    image_paths: List[str] = []
    for ext in extensions:
        pattern = os.path.join(save_path, '**', f'*.{ext}')
        image_paths.extend(sorted(glob.glob(pattern, recursive=True)))
    return image_paths
def get_cmd_path() -> str:
    """Return the absolute directory that contains this script (symlinks are not resolved)."""
    return os.path.dirname(os.path.abspath(__file__))
def create_save_path(base_path: str) -> Dict[str, str]:
    """Create the per-run output sub-directories under ``base_path``.

    Args:
        base_path: root output directory.

    Returns:
        Mapping role -> directory path, with keys in the order ``dataset``,
        ``predict``, ``post_process`` and ``plot_view``. Every directory exists
        on return; existing directories are reused.
    """
    layout = {
        "dataset": "predict_dataset",
        "predict": "predict_result",
        "post_process": "predict_post_process",
        "plot_view": "predict_result_view",
    }
    created: Dict[str, str] = {}
    for role, folder in layout.items():
        target = os.path.join(base_path, folder)
        os.makedirs(target, exist_ok=True)
        created[role] = target
    return created
def find_image_json_pairs(directory: str) -> List[Dict[str, str]]:
    """Pair each ``.jpg`` image in ``directory`` with its same-stem ``.json`` file.

    Only the top level of ``directory`` is scanned (no recursion). An image is
    included when both it and a regular file named ``<stem>.json`` exist next
    to each other. Images without a JSON, and JSON files without an image, are
    skipped.

    Args:
        directory: folder to scan.

    Returns:
        One ``{'jpg': <absolute image path>, 'json': <absolute json path>}``
        dict per image, ordered by file name. Each image appears at most once,
        regardless of the order ``os.listdir`` returns entries in.
    """
    pairs: List[Dict[str, str]] = []
    for filename in sorted(os.listdir(directory)):
        stem, ext = os.path.splitext(filename)
        # The extension test is case-insensitive, but the real on-disk name is
        # kept, so "a.JPG" is reported as "a.JPG" rather than a missing "a.jpg".
        if ext.lower() != '.jpg':
            continue
        jpg_path = os.path.join(directory, filename)
        json_path = os.path.join(directory, stem + '.json')
        if os.path.isfile(jpg_path) and os.path.isfile(json_path):
            pairs.append({'jpg': os.path.abspath(jpg_path), 'json': os.path.abspath(json_path)})
    return pairs
def predict_and_plot(model_path: str, img_paths: str, output_dir: str, progress: st.progress,
                     is_after_train: bool = False) -> None:
    """Run the predict -> JSON post-process -> plot -> report pipeline.

    Args:
        model_path: checkpoint path, passed straight to ``predict``.
        img_paths: a zip archive path (default). When ``is_after_train`` is
            true, it is instead a directory searched recursively for ``.jpg`` files.
        output_dir: root under which ``create_save_path`` lays out the run's
            sub-directories, and where the report is saved.
        progress: Streamlit progress bar (``st.progress``); advanced to
            20 / 60 / 80 / 100 as each stage completes.
        is_after_train: selects how ``img_paths`` is interpreted (see above).
    """
    dirs = create_save_path(output_dir)

    if is_after_train:
        image_paths = glob.glob(img_paths + "/**/*.jpg", recursive=True)
    else:
        image_paths = zip_file(img_paths, dirs["dataset"])
    progress.progress(20)

    # No pre-processing step (process_slide) is run before prediction.
    prediction_file = predict(model_path, image_paths, dirs["predict"])
    save_pred_to_json(prediction_file, dirs["post_process"])
    progress.progress(60)

    for pair in find_image_json_pairs(dirs["post_process"]):
        plot2(pair['jpg'], pair['json'], dirs["plot_view"])
    progress.progress(80)

    generate_report(dirs, output_dir)
    progress.progress(100)
def remove_prefix(full_path: str, prefix: str) -> str:
    """Express ``full_path`` relative to the directory ``prefix``.

    This delegates to ``os.path.relpath``, so a path outside ``prefix`` comes
    back with leading ``..`` components instead of raising.
    """
    relative = os.path.relpath(full_path, start=prefix)
    return relative
def generate_report(save_path: Dict[str, str], output_dir: str) -> None:
    """Build the prediction report and save it into ``output_dir``.

    The report has two sections, in this order:
    the rendered prediction images (``*.jpg`` directly in
    ``save_path["plot_view"]``, 3 columns), then the original input images
    (``*.jpg`` / ``*.png`` anywhere under ``save_path["dataset"]``, 2 columns).
    Element paths are stored relative to ``output_dir``; titles are file names.

    Args:
        save_path: directory mapping as returned by ``create_save_path``.
        output_dir: report root, passed to ``Report.save``.
    """
    def _image_elements(paths, description):
        # One report element per image, in sorted (deterministic) order.
        # os.path.basename, unlike split("/"), also handles backslash paths.
        return [
            AutoReportElement(
                path=remove_prefix(img_path, output_dir),
                title=os.path.basename(img_path),
                description=description,
            )
            for img_path in sorted(paths)
        ]

    dataset_dir = save_path["dataset"]
    # zip_file collects both .jpg and .png inputs for prediction, so both
    # kinds are listed as originals.
    original_paths = (
        glob.glob(os.path.join(dataset_dir, "**", "*.jpg"), recursive=True)
        + glob.glob(os.path.join(dataset_dir, "**", "*.png"), recursive=True)
    )
    ori_img_section = ReportSection(
        title="原始图片", ncols=2, elements=_image_elements(original_paths, '原始图片'))

    result_paths = glob.glob(os.path.join(save_path["plot_view"], "*.jpg"))
    post_process_img_section = ReportSection(
        title="预测结果", ncols=3, elements=_image_elements(result_paths, '预测结果'))

    report = Report(title="原子定位结果", sections=[post_process_img_section, ori_img_section])
    report.save(output_dir)
# ---------------------------------------------------------------------------
# Streamlit app: page header, upload widget and fixed model/output locations.
# ---------------------------------------------------------------------------
st.title("原子位置缺陷预测")
st.write("上传一个包含图片的zip文件进行预测。")
# Zip archive of images to predict on; falsy until the user uploads a file.
# NOTE(review): a 500MB upload limit was mentioned here, but nothing in this
# file sets it. Presumably it comes from Streamlit's server.maxUploadSize
# config; confirm.
data_path = st.file_uploader("上传数据文件 (.zip)", type=["zip"])
# The checkpoint is fixed at <script dir>/model/last.ckpt; there is no
# checkpoint upload widget. Results are written under <script dir>/result.
model_path = os.path.join(get_cmd_path(), "model/last.ckpt")
output_dir = os.path.join(get_cmd_path(), "result")
def list_files_and_dirs_in_directory(directory: str) -> Dict[str, List[str]]:
    """Map every directory under ``directory`` to the files it directly contains.

    Keys are paths relative to ``directory`` (the root itself is ``"."``).
    Values are the file-name lists reported by ``os.walk``, in walk order.
    """
    return {
        os.path.relpath(current, directory): names
        for current, _subdirs, names in os.walk(directory)
    }
def display_directory_structure(base_dir: str, file_dict: Dict[str, List[str]], container=None):
    """Render ``file_dict`` as a clickable file tree and return the clicked file.

    Args:
        base_dir: directory that the keys of ``file_dict`` are relative to.
        file_dict: mapping of relative dir -> file names (``"."`` is the
            root), e.g. as produced by ``list_files_and_dirs_in_directory``.
        container: Streamlit container to render into. Defaults to the main
            page (``st``); pass ``st.sidebar`` to show the tree in the sidebar.

    Returns:
        Full path (``base_dir/dir/file``) of the file whose button was
        clicked in this script run, or ``None`` if no button was clicked.
    """
    target = st if container is None else container
    target.title("目录结构")
    selected = None
    for dir_path, files in file_dict.items():
        target.subheader("根目录" if dir_path == "." else dir_path)
        for file in files:
            full_path = os.path.join(base_dir, dir_path, file)
            # Keep rendering after a click. Returning early would drop the
            # rest of the tree from the page on the rerun the click triggers.
            if target.button(file, key=full_path) and selected is None:
                selected = full_path
    return selected
def list_files(root_dir):
    """Show the images in known result folders found anywhere under ``root_dir``.

    Walks ``root_dir`` recursively. For every directory whose name is
    ``predict_result_view`` or ``train_original``, it writes the matching
    section header, then shows each ``.jpg`` / ``.png`` file there (extension
    matched case-insensitively) with ``st.image``, captioned with its file
    name. Other directories are skipped.

    NOTE(review): predict_and_plot extracts the original images into
    ``predict_dataset``, not ``train_original``. Confirm which folder the
    original-image section is meant to show.
    """
    headers = {
        'predict_result_view': '预测结果图片',
        'train_original': '原始图片',
    }
    for root, dirs, files in os.walk(root_dir):
        # Sorting in place makes the top-down walk visit subdirectories in a
        # stable order, so images appear in the same order on every run.
        dirs.sort()
        header = headers.get(os.path.basename(root))
        if header is None:
            continue
        st.header(header)
        for f in sorted(files):
            if f.lower().endswith(('.jpg', '.png')):
                st.image(os.path.join(root, f), caption=f)
# Run the pipeline once a zip has been uploaded and the button is clicked.
if data_path and st.button("运行预测"):
    # Stage the upload in a private temporary directory instead of a fixed
    # "temp.zip" in the current working directory. The fixed name was shared
    # by concurrent sessions, so one upload could overwrite another, and the
    # file was never deleted. The directory is removed when prediction
    # finishes or raises.
    with tempfile.TemporaryDirectory() as upload_dir:
        zip_path = os.path.join(upload_dir, "upload.zip")
        with open(zip_path, "wb") as f:
            f.write(data_path.getbuffer())
        progress = st.progress(0)
        with st.spinner("预测处理中,请稍候..."):
            predict_and_plot(model_path, zip_path, output_dir, progress)
    st.success("预测完成。以下是结果文件预览:")
    # NOTE(review): output_dir is shared by every run and session and is never
    # cleared, so images from earlier runs are listed here as well. Consider a
    # per-run sub-directory.
    list_files(output_dir)