atom-predict/msunet/streamlit_app.py

176 lines
5.8 KiB
Python
Raw Normal View History

2024-07-24 16:10:08 +08:00
import streamlit as st
import os
import glob
import shutil
import argparse
import os
import zipfile
import glob
from typing import List
from pre import process_slide
from test_pl_unequal import predict
from typing import Dict
from typing import Tuple
from utils.labelme import save_pred_to_json
from plot import plot2
from predict import predict_and_plot
from train_pl import train, pre_process, create_save_path
import numpy as np
from dp.launching.report import Report, ReportSection, AutoReportElement
import sys
from pathlib import Path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn_v2')))
print(sys.path)
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
from train_pl_v2_aug import train as egnn_train
2024-07-24 16:10:08 +08:00
def get_model_options(model_dir="../egnn_v2/model"):
    """Return the names of the ``*.ckpt`` checkpoints found in ``model_dir``.

    Args:
        model_dir: Directory searched (non-recursively) for checkpoint files.
            A relative path is resolved against the current working directory.

    Returns:
        Sorted list of checkpoint file names with the trailing extension
        removed. Empty when the directory is missing or has no checkpoints.
    """
    ckpt_files = glob.glob(os.path.join(model_dir, "*.ckpt"))
    # Strip only the trailing extension: ``str.replace`` would also delete a
    # ".ckpt" in the middle of a name, so the option could no longer be mapped
    # back to its file by appending ".ckpt".
    # Sorting keeps the option order stable, since glob order is unspecified.
    return sorted(os.path.splitext(os.path.basename(f))[0] for f in ckpt_files)
2024-07-25 10:15:32 +08:00
def load_images_from_folder(folder_path):
    """Load every PNG/JPEG image located directly inside ``folder_path``.

    Args:
        folder_path: Directory to scan (sub-directories are not visited).

    Returns:
        List of ``(filename, image)`` tuples ordered by file name, where
        ``image`` is a fully loaded ``PIL.Image.Image``. Empty when the folder
        does not exist or holds no matching files.
    """
    if not os.path.isdir(folder_path):
        return []

    # Compare extensions case-insensitively so "PLOT.PNG" is picked up too,
    # and sort because os.listdir returns entries in arbitrary order.
    image_names = sorted(
        name for name in os.listdir(folder_path)
        if name.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    if not image_names:
        return []

    # The module never imports PIL at the top, so ``Image`` was an undefined
    # name here. Import it locally, and only once there is something to load.
    from PIL import Image

    images = []
    for name in image_names:
        # Image.open is lazy and keeps the file open; load the pixel data now
        # and let the context manager release the file handle.
        with Image.open(os.path.join(folder_path, name)) as img:
            img.load()
        images.append((name, img))
    return images
2024-07-24 16:10:08 +08:00
# Seed the session-scoped model selection the first time the script runs.
if "models" not in st.session_state:
    st.session_state["models"] = list()

# Page title.
st.title("Bohrium Apps 开发者中心")

# Create the two tabs of the page; they are filled in further down.
_tab_titles = ["combine_gnn预测", "新增模型"]
tabs = st.tabs(_tab_titles)
def run_task(data_file, select_models, edges_length):
    """Run the first-stage prediction and then every selected GNN model.

    The upload is written to a ``temp`` folder next to this script, passed to
    ``predict_and_plot`` together with ``model/last.ckpt``, and the
    ``"post_process"`` entry of its result is fed to ``egnn_predict_and_plot``
    once per selected GNN checkpoint.

    Args:
        data_file: Uploaded zip archive; must provide ``name`` and
            ``getbuffer()``.
        select_models: Iterable of GNN checkpoint names without the ``.ckpt``
            extension, looked up in ``../egnn_v2/model`` relative to this
            script.
        edges_length: Value forwarded unchanged to ``egnn_predict_and_plot``.

    Returns:
        Dict with the keys ``"result_view"`` and ``"connect_view"`` pointing
        at the ``gnn_predict_result_view`` / ``gnn_predict_connect_view``
        folders under the output directory.
    """
    # Echo the inputs in the page for debugging.
    st.write(f"Edges Length: {edges_length}")
    st.write(f"Selected Models: {select_models}")
    st.write(f"data_file: {data_file.name}")

    # Store the archive in a `temp` folder next to this script (absolute path).
    base_dir = os.path.dirname(os.path.abspath(__file__))
    temp_dir = os.path.join(base_dir, "temp")
    os.makedirs(temp_dir, exist_ok=True)
    # The file name is supplied by the client: keep only its last component so
    # that a name containing directory parts cannot be written outside temp_dir.
    upload_name = os.path.basename(data_file.name)
    temp_file_path = os.path.join(temp_dir, upload_name)
    with open(temp_file_path, "wb") as f:
        f.write(data_file.getbuffer())

    print("select_models的类型是", type(select_models))
    print("选择的模型:",select_models)
    for model in select_models:
        print("选择的模型的名称:",model)

    output_dir = "./streamlit_app_result"
    model_path = os.path.join(base_dir, "model/last.ckpt")
    save_path = predict_and_plot(model_path, temp_file_path, output_dir)
    test_path = save_path["post_process"]

    print("gnn 预测开始")
    gnn_model_dir = os.path.join(base_dir, "../egnn_v2/model")
    for opt in select_models:
        gnn_model_path = os.path.join(gnn_model_dir, opt + ".ckpt")
        print("gnn model path :", gnn_model_path)
        # NOTE(review): every model writes to the same output_dir, so later
        # models presumably overwrite earlier results — confirm intended.
        egnn_predict_and_plot(gnn_model_path, test_path, output_dir, edges_length)
    print("gnn 预测结束")

    return {
        "result_view": os.path.join(output_dir, "gnn_predict_result_view"),
        "connect_view": os.path.join(output_dir, "gnn_predict_connect_view"),
    }
# Cache of the folders whose images get displayed; both entries start unset.
result_paths = dict.fromkeys(("result_view", "connect_view"))
2024-07-24 16:10:08 +08:00
with tabs[0]:
    # Zip archive to run the prediction on.
    data_file = st.file_uploader("Data Path", type=["zip"], help="测试的缺陷数据集")

    # List the checkpoints from the directory run_task loads them from
    # (relative to this script), so every option shown can actually be found
    # regardless of the working directory the app was started in.
    gnn_model_dir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "../egnn_v2/model"
    )
    model_options = get_model_options(gnn_model_dir)

    # Drop remembered selections whose checkpoint no longer exists:
    # st.multiselect raises when a default is not among the options.
    default_models = [m for m in st.session_state['models'] if m in model_options]
    selected_models = st.multiselect("选择模型", model_options, default=default_models)
    st.session_state['models'] = selected_models

    # Edge length forwarded to the GNN prediction.
    edges_length = st.number_input("Edges Length", value=35.5)

    # Run button.
    if st.button("运行任务"):
        if not data_file:
            st.error("请上传所有必需的文件")
        else:
            st.success("任务运行中...")
            select_models = st.session_state['models']
            # Keep the returned folders: previously the return value was
            # discarded, so the result section below never had anything to
            # show. Session state also keeps them across script reruns.
            st.session_state['result_paths'] = run_task(
                data_file, select_models, edges_length
            )
            st.success("任务完成")

    # Show the results of the most recent run, if any.
    if 'result_paths' in st.session_state:
        shown_paths = st.session_state['result_paths']
    else:
        shown_paths = result_paths
    for view_key, heading in (
        ("result_view", "GNN 预测结果"),
        ("connect_view", "连接视图"),
    ):
        view_dir = shown_paths.get(view_key)
        # The folder may be absent, e.g. when no GNN model was selected.
        if view_dir and os.path.isdir(view_dir):
            st.subheader(heading)
            for filename, img in load_images_from_folder(view_dir):
                st.image(img, caption=filename, use_column_width=True)
2024-07-24 16:10:08 +08:00
with tabs[1]:
    st.header("新增模型")
    # Upload widget for a new checkpoint file.
    new_model_file = st.file_uploader("Model Path", type=["ckpt"], help="上传新的模型文件 (.ckpt)")
    if st.button("提交"):
        if new_model_file is None:
            st.error("请上传模型文件")
        else:
            # Save into the directory run_task loads GNN checkpoints from
            # (relative to this script, not to the working directory).
            model_dir = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), "../egnn_v2/model"
            )
            os.makedirs(model_dir, exist_ok=True)
            # The file name is supplied by the client: keep only its last
            # component so the upload cannot be written outside model_dir.
            model_file_path = os.path.join(
                model_dir, os.path.basename(new_model_file.name)
            )
            # NOTE(review): checkpoints are presumably deserialized with
            # pickle-based loading later on — only accept uploads from trusted
            # users; confirm.
            with open(model_file_path, "wb") as f:
                f.write(new_model_file.getbuffer())
            st.success(f"模型文件已保存到 {model_file_path}")
            st.rerun()  # Rerun the script so the model options are refreshed.
2024-07-24 16:10:08 +08:00