176 lines
5.8 KiB
Python
176 lines
5.8 KiB
Python
import streamlit as st
|
||
import os
|
||
import glob
|
||
import shutil
|
||
import argparse
|
||
import os
|
||
import zipfile
|
||
import glob
|
||
from typing import List
|
||
from pre import process_slide
|
||
from test_pl_unequal import predict
|
||
from typing import Dict
|
||
from typing import Tuple
|
||
from utils.labelme import save_pred_to_json
|
||
from plot import plot2
|
||
from predict import predict_and_plot
|
||
from train_pl import train, pre_process, create_save_path
|
||
import numpy as np
|
||
from dp.launching.report import Report, ReportSection, AutoReportElement
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
|
||
# Make the sibling ``egnn_v2`` directory importable for the imports that follow.
EGNN_V2_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn_v2'))
# Streamlit re-executes this script on every user interaction, so only append
# the entry once instead of growing sys.path with duplicates on each rerun.
if EGNN_V2_DIR not in sys.path:
    sys.path.append(EGNN_V2_DIR)
|
||
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
|
||
from train_pl_v2_aug import train as egnn_train
|
||
|
||
def get_model_options(model_dir: str = "../egnn_v2/model") -> List[str]:
    """Return the names of the ``*.ckpt`` checkpoints in *model_dir*.

    Each name is the file's basename with only the trailing ``.ckpt``
    extension removed. That way ``name + ".ckpt"`` always points back to the
    same file, even when ``.ckpt`` also appears earlier in the name.

    NOTE(review): a relative *model_dir* (including the default) is resolved
    against the current working directory, not the script's directory.

    Args:
        model_dir: Directory to scan (non-recursive).

    Returns:
        Sorted list of checkpoint names. The list is empty if the directory
        is missing or holds no ``.ckpt`` files.
    """
    # Sort so the UI lists models in a stable order across reruns.
    ckpt_files = sorted(glob.glob(os.path.join(model_dir, "*.ckpt")))
    # splitext removes only the final extension, unlike str.replace, which
    # would strip every ".ckpt" substring from the name.
    return [os.path.splitext(os.path.basename(path))[0] for path in ckpt_files]
|
||
|
||
def load_images_from_folder(folder_path: str) -> List[Tuple[str, bytes]]:
    """Load every PNG/JPEG file directly inside *folder_path*.

    The images are returned as raw encoded bytes, which ``st.image`` accepts
    directly. This avoids depending on PIL, whose ``Image`` name was never
    imported in this file.

    Args:
        folder_path: Directory to scan (non-recursive).

    Returns:
        List of ``(filename, image_bytes)`` tuples, sorted by filename.
        Extension matching is case-insensitive.

    Raises:
        FileNotFoundError: if *folder_path* does not exist.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    images = []
    # os.listdir order is arbitrary; sort for a stable display order.
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(image_exts):
            continue
        img_path = os.path.join(folder_path, filename)
        # Skip directories that happen to carry an image-like name.
        if not os.path.isfile(img_path):
            continue
        with open(img_path, "rb") as fh:
            images.append((filename, fh.read()))
    return images
|
||
|
||
|
||
# Per-session state: the list of GNN models the user picked, kept across reruns.
if "models" not in st.session_state:
    st.session_state.models = []

# Page heading, then the two tabs: prediction and model upload.
st.title("Bohrium Apps 开发者中心")
tabs = st.tabs(["combine_gnn预测", "新增模型"])
|
||
|
||
def run_task(data_file, select_models, edges_length):
    """Run the two-stage prediction pipeline on an uploaded zip archive.

    Stage 1 calls ``predict_and_plot`` with the fixed checkpoint
    ``<script dir>/model/last.ckpt``. Stage 2 calls ``egnn_predict_and_plot``
    once per selected GNN checkpoint on stage 1's ``"post_process"`` output.

    Args:
        data_file: Uploaded file object; only ``.name`` and ``.getbuffer()``
            are used. Its contents are saved to ``<script dir>/temp``.
        select_models: Names of GNN checkpoints, without the ``.ckpt``
            suffix, located in ``<script dir>/../egnn_v2/model``.
        edges_length: Forwarded unchanged to ``egnn_predict_and_plot``.

    Returns:
        dict with keys ``"result_view"`` and ``"connect_view"``. These are
        the folders under the output directory where the GNN stage's images
        are expected (assumption based on the folder names; confirm against
        ``egnn_predict_and_plot``).
    """
    # Echo the inputs on the page for debugging.
    st.write(f"Edges Length: {edges_length}")
    st.write(f"Selected Models: {select_models}")
    st.write(f"data_file: {data_file.name}")

    # Save the uploaded zip into <script dir>/temp, using an absolute path.
    base_dir = os.path.dirname(os.path.abspath(__file__))
    temp_dir = os.path.join(base_dir, "temp")
    os.makedirs(temp_dir, exist_ok=True)
    # The upload name is client-supplied; keep only its last path component
    # so a name such as "../../x.zip" cannot write outside temp_dir.
    temp_file_path = os.path.join(temp_dir, os.path.basename(data_file.name))
    with open(temp_file_path, "wb") as f:
        f.write(data_file.getbuffer())

    print("选择的模型:", select_models)

    # NOTE(review): output_dir is relative to the current working directory,
    # unlike temp_dir and the checkpoint paths, which are script-relative.
    output_dir = "./streamlit_app_result"
    model_path = os.path.join(base_dir, "model/last.ckpt")
    save_path = predict_and_plot(model_path, temp_file_path, output_dir)
    test_path = save_path["post_process"]

    print("gnn 预测开始")
    gnn_model_dir = os.path.join(base_dir, "../egnn_v2/model")
    for opt in select_models:
        gnn_model_path = os.path.join(gnn_model_dir, opt + ".ckpt")
        print("gnn model path :", gnn_model_path)
        egnn_predict_and_plot(gnn_model_path, test_path, output_dir, edges_length)
    print("gnn 预测结束")

    return {
        "result_view": os.path.join(output_dir, "gnn_predict_result_view"),
        "connect_view": os.path.join(output_dir, "gnn_predict_connect_view"),
    }
|
||
|
||
# Folders holding the result images to display; None means nothing to show yet.
result_paths = dict.fromkeys(("result_view", "connect_view"))
|
||
|
||
|
||
with tabs[0]:
    # Zip archive of the defect data set to run prediction on.
    data_file = st.file_uploader("Data Path", type=["zip"], help="测试的缺陷数据集")

    # GNN checkpoints currently available on disk.
    model_options = get_model_options()
    # A previously selected checkpoint may have been removed from disk.
    # Streamlit raises if a multiselect default is not among the options,
    # so keep only the remembered choices that still exist.
    remembered = [m for m in st.session_state['models'] if m in model_options]
    selected_models = st.multiselect("选择模型", model_options, default=remembered)
    st.session_state['models'] = selected_models

    edges_length = st.number_input("Edges Length", value=35.5)

    if st.button("运行任务"):
        if not data_file:
            st.error("请上传所有必需的文件")
        else:
            with st.spinner("任务运行中..."):
                # Keep the output folders in session_state so the results
                # remain visible on later reruns, not just this one.
                st.session_state["result_paths"] = run_task(
                    data_file, selected_models, edges_length
                )
            st.success("任务完成")

    # Fill the display dict from the most recent successful run, if any.
    if "result_paths" in st.session_state:
        result_paths.update(st.session_state["result_paths"])

    for view_key, view_title in (("result_view", "GNN 预测结果"), ("connect_view", "连接视图")):
        folder = result_paths[view_key]
        if not folder:
            continue
        # The GNN stage may not have produced this folder (e.g. no model was
        # selected); report it instead of crashing in os.listdir.
        if not os.path.isdir(folder):
            st.warning(f"未找到结果目录: {folder}")
            continue
        st.subheader(view_title)
        for filename, img in load_images_from_folder(folder):
            st.image(img, caption=filename, use_column_width=True)
|
||
|
||
with tabs[1]:
    st.header("新增模型")

    # Show the confirmation stored before st.rerun(). A message written
    # directly before the rerun would be discarded and never seen.
    if "model_saved_msg" in st.session_state:
        st.success(st.session_state["model_saved_msg"])
        del st.session_state["model_saved_msg"]

    # Upload a new GNN checkpoint file.
    new_model_file = st.file_uploader("Model Path", type=["ckpt"], help="上传新的模型文件 (.ckpt)")

    if st.button("提交"):
        if new_model_file is None:
            st.error("请上传模型文件")
        else:
            # Same CWD-relative directory that get_model_options() scans by default.
            model_dir = "../egnn_v2/model"
            os.makedirs(model_dir, exist_ok=True)
            # The upload name is client-supplied; keep only its last path
            # component so it cannot write outside model_dir.
            model_file_path = os.path.join(model_dir, os.path.basename(new_model_file.name))

            with open(model_file_path, "wb") as f:
                f.write(new_model_file.getbuffer())

            st.session_state["model_saved_msg"] = f"模型文件已保存到 {model_file_path}"
            # Rerun so the prediction tab's model list picks up the new file.
            st.rerun()
|
||
|
||
|