atom-predict/msunet/streamlit_app.py

176 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import streamlit as st
import os
import glob
import shutil
import argparse
import os
import zipfile
import glob
from typing import List
from pre import process_slide
from test_pl_unequal import predict
from typing import Dict
from typing import Tuple
from utils.labelme import save_pred_to_json
from plot import plot2
from predict import predict_and_plot
from train_pl import train, pre_process, create_save_path
import numpy as np
from dp.launching.report import Report, ReportSection, AutoReportElement
import sys
from pathlib import Path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn_v2')))
print(sys.path)
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
from train_pl_v2_aug import train as egnn_train
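# List the `*.ckpt` checkpoints in `model_dir` (default `../egnn_v2/model`, resolved against the current working directory, not the script directory).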
def get_model_options(model_dir="../egnn_v2/model"):
    """Return the names (without ``.ckpt``) of the checkpoints in *model_dir*.

    A relative *model_dir* is resolved against the current working directory.
    Names are returned sorted so the option list shown in the UI is stable
    between reruns (``glob`` result order depends on the file system).

    Args:
        model_dir: Directory scanned (non-recursively) for ``*.ckpt`` files.

    Returns:
        Sorted list of checkpoint names with only the trailing ``.ckpt``
        suffix removed.
    """
    suffix = ".ckpt"
    ckpt_files = glob.glob(os.path.join(model_dir, "*" + suffix))
    # Strip only the trailing suffix; str.replace would also delete a ".ckpt"
    # occurring earlier in the name (e.g. "a.ckpt_old.ckpt" -> "a_old").
    return sorted(os.path.basename(path)[: -len(suffix)] for path in ckpt_files)
def load_images_from_folder(folder_path):
    """Collect the image files located directly inside *folder_path*.

    Only regular files whose names end in ``.png``, ``.jpg`` or ``.jpeg``
    (case-insensitive) are returned, sorted by file name so they render in a
    stable order.

    The second element of each pair is the image's file path, which
    ``st.image`` accepts directly. This keeps the helper free of any imaging
    library and does not leave lazily-opened file handles behind.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        List of ``(filename, path)`` tuples.

    Raises:
        OSError (e.g. FileNotFoundError): propagated from ``os.listdir`` when
        *folder_path* is not an existing directory.
    """
    extensions = (".png", ".jpg", ".jpeg")
    images = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(extensions):
            continue
        img_path = os.path.join(folder_path, filename)
        # Skip directories that merely have an image-like name.
        if os.path.isfile(img_path):
            images.append((filename, img_path))
    return images
# First run of the session: start with an empty remembered model selection.
st.session_state.setdefault('models', [])

# Page heading.
st.title("Bohrium Apps 开发者中心")

# Two tabs: GNN prediction and uploading a new model.
tabs = st.tabs(["combine_gnn预测", "新增模型"])
def run_task(data_file, select_models, edges_length):
    """Run the two-stage prediction pipeline on an uploaded zip dataset.

    Stage 1 calls ``predict_and_plot`` with the bundled ``model/last.ckpt``
    checkpoint. Its ``"post_process"`` output is then passed to
    ``egnn_predict_and_plot`` once per selected GNN checkpoint.

    Args:
        data_file: Streamlit uploaded file holding the zip dataset (uses its
            ``name`` and ``getbuffer()``).
        select_models: GNN checkpoint names (without ``.ckpt``) looked up in
            ``../egnn_v2/model`` relative to this script.
        edges_length: Value forwarded to ``egnn_predict_and_plot``.

    Returns:
        Dict with ``"result_view"`` and ``"connect_view"`` directory paths
        (relative to the current working directory) where the GNN plots are
        expected. The directories are not guaranteed to exist, e.g. when no
        model was selected.
    """
    # Debug information rendered on the page.
    st.write(f"Edges Length: {edges_length}")
    st.write(f"Selected Models: {select_models}")
    st.write(f"data_file: {data_file.name}")

    # Save the upload into `temp/` next to this script (absolute path).
    base_dir = os.path.dirname(os.path.abspath(__file__))
    temp_dir = os.path.join(base_dir, "temp")
    os.makedirs(temp_dir, exist_ok=True)
    # The upload name is client-supplied. Keep only its last component so a
    # crafted name such as "../../x.zip" cannot write outside temp_dir.
    temp_file_path = os.path.join(temp_dir, os.path.basename(data_file.name))
    with open(temp_file_path, "wb") as f:
        f.write(data_file.getbuffer())

    print("select_models的类型是", type(select_models))
    print("选择的模型:", select_models)
    for model in select_models:
        print("选择的模型的名称:", model)

    # Stage 1: prediction with the bundled checkpoint.
    output_dir = "./streamlit_app_result"
    model_path = os.path.join(base_dir, "model/last.ckpt")
    save_path = predict_and_plot(model_path, temp_file_path, output_dir)
    test_path = save_path["post_process"]

    # Stage 2: GNN prediction, once per selected checkpoint.
    # NOTE(review): every model is given the same output_dir. Whether later
    # runs overwrite earlier plots depends on egnn_predict_and_plot — confirm.
    print("gnn 预测开始")
    gnn_model_dir = os.path.join(base_dir, "../egnn_v2/model")
    for opt in select_models:
        gnn_model_path = os.path.join(gnn_model_dir, opt + ".ckpt")
        print("gnn model path :", gnn_model_path)
        egnn_predict_and_plot(gnn_model_path, test_path, output_dir, edges_length)
    print("gnn 预测结束")

    return {
        "result_view": os.path.join(output_dir, "gnn_predict_result_view"),
        "connect_view": os.path.join(output_dir, "gnn_predict_connect_view"),
    }
# Result directories to display on the page; None means "nothing to show".
result_paths = dict.fromkeys(("result_view", "connect_view"))
with tabs[0]:
    # Zip dataset to run the prediction on.
    data_file = st.file_uploader("Data Path", type=["zip"], help="测试的缺陷数据集")

    # GNN checkpoints available for selection.
    model_options = get_model_options()
    # Drop remembered selections whose checkpoint is no longer listed:
    # st.multiselect raises if `default` contains values missing from `options`.
    previous_selection = [m for m in st.session_state['models'] if m in model_options]
    selected_models = st.multiselect("选择模型", model_options, default=previous_selection)
    st.session_state['models'] = selected_models

    edges_length = st.number_input("Edges Length", value=35.5)

    if st.button("运行任务"):
        if not data_file:
            st.error("请上传所有必需的文件")
        else:
            st.success("任务运行中...")
            select_models = st.session_state['models']
            # Keep the returned directories so the results section below can
            # display them.
            result_paths.update(run_task(data_file, select_models, edges_length))
            st.success("任务完成")

    # Show results. Skip views whose directory was not produced
    # (e.g. no GNN model selected); os.listdir would raise on them.
    for view_key, view_title in (
        ("result_view", "GNN 预测结果"),
        ("connect_view", "连接视图"),
    ):
        view_dir = result_paths[view_key]
        if view_dir and os.path.isdir(view_dir):
            st.subheader(view_title)
            for filename, img in load_images_from_folder(view_dir):
                st.image(img, caption=filename, use_column_width=True)
with tabs[1]:
    st.header("新增模型")
    # Upload a new GNN checkpoint.
    new_model_file = st.file_uploader("Model Path", type=["ckpt"], help="上传新的模型文件 (.ckpt)")
    if st.button("提交"):
        if new_model_file is None:
            st.error("请上传模型文件")
        else:
            # Same directory that get_model_options() scans by default
            # (relative to the current working directory).
            model_dir = "../egnn_v2/model"
            os.makedirs(model_dir, exist_ok=True)
            # The file name is client-supplied. Keep only its last component so
            # it cannot escape model_dir (e.g. "../../evil.ckpt").
            model_file_path = os.path.join(model_dir, os.path.basename(new_model_file.name))
            with open(model_file_path, "wb") as f:
                f.write(new_model_file.getbuffer())
            st.success(f"模型文件已保存到 {model_file_path}")
            # Rerun so the model list in the first tab is refreshed.
            # NOTE(review): rerunning immediately likely hides the success
            # message above — consider st.toast; confirm.
            st.rerun()