streamlit_app动态更新和新增模型

This commit is contained in:
somunslotus 2024-07-24 17:28:37 +08:00
parent e871c3b8ed
commit c4a45a925f
1 changed files with 73 additions and 21 deletions

View File

@ -2,10 +2,32 @@ import streamlit as st
import os import os
import glob import glob
import shutil import shutil
import argparse
import os
import zipfile
import glob
from typing import List
from pre import process_slide
from test_pl_unequal import predict
from typing import Dict
from typing import Tuple
from utils.labelme import save_pred_to_json
from plot import plot2
from predict import predict_and_plot
from train_pl import train, pre_process, create_save_path
import numpy as np
from dp.launching.report import Report, ReportSection, AutoReportElement
import sys
from pathlib import Path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn_v2')))
print(sys.path)
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
from train_pl_v2_aug import train as egnn_train
# 获取运行脚本目录中的 `model` 文件夹内的所有 `*.ckpt` 文件 # 获取运行脚本目录中的 `model` 文件夹内的所有 `*.ckpt` 文件
def get_model_options(model_dir="model"): def get_model_options(model_dir="../egnn_v2/model"):
ckpt_files = glob.glob(os.path.join(model_dir, "*.ckpt")) ckpt_files = glob.glob(os.path.join(model_dir, "*.ckpt"))
model_options = [os.path.basename(f).replace(".ckpt", "") for f in ckpt_files] model_options = [os.path.basename(f).replace(".ckpt", "") for f in ckpt_files]
return model_options return model_options
@ -21,20 +43,57 @@ st.title("Bohrium Apps 开发者中心")
# 创建 Tabs # 创建 Tabs
tabs = st.tabs(["combine_gnn预测", "新增模型"]) tabs = st.tabs(["combine_gnn预测", "新增模型"])
def run_task(data_file,select_models, edges_length):
# 打印调试信息
st.write(f"Edges Length: {edges_length}")
st.write(f"Selected Models: {select_models}")
st.write(f"data_file: {data_file.name}")
# 保存zip文件到临时目录, 临时目录为当前目录下的 `temp` 文件夹,使用绝对路径
base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)))
temp_dir = os.path.join(base_dir, "temp")
os.makedirs(temp_dir, exist_ok=True)
temp_file_path = os.path.join(temp_dir, data_file.name)
with open(temp_file_path, "wb") as f:
f.write(data_file.getbuffer())
print("select_models的类型是", type(select_models))
print("选择的模型:",select_models)
for model in select_models:
print("选择的模型的名称:",model)
output_dir = "./streamlit_app_result"
model_path = os.path.join(base_dir, "model/last.ckpt")
save_path = predict_and_plot(model_path, temp_file_path, output_dir)
test_path = save_path["post_process"]
print("gnn 预测开始")
for opt in select_models:
gnn_model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../egnn_v2/model", opt+".ckpt")
print("gnn model path :", gnn_model_path)
egnn_predict_and_plot(gnn_model_path, test_path, output_dir, edges_length)
print("gnn 预测结束")
with tabs[0]: with tabs[0]:
st.header("combine_gnn预测") # st.header("combine_gnn预测")
# Data Path # Data Path
data_file = st.file_uploader("Data Path", type=["zip"], help="测试的缺陷数据集") data_file = st.file_uploader("Data Path", type=["zip"], help="测试的缺陷数据集")
# Model Path # Model Path
model_file = st.file_uploader("Model Path", type=["ckpt"], help="原子缺陷处理模型") #model_file = st.file_uploader("Model Path", type=["ckpt"], help="原子缺陷处理模型")
# Gnn Model Path # Gnn Model Path
gnn_model_file = st.file_uploader("Gnn Model Path", type=["ckpt"], help="gnn模型") #gnn_model_file = st.file_uploader("Gnn Model Path", type=["ckpt"], help="gnn模型")
# Select Models # Select Models
st.subheader("Select Models") # st.subheader("Select Models")
# 获取模型选项 # 获取模型选项
model_options = get_model_options() model_options = get_model_options()
@ -48,25 +107,19 @@ with tabs[0]:
# 运行任务按钮 # 运行任务按钮
if st.button("运行任务"): if st.button("运行任务"):
if not data_file or not model_file or not gnn_model_file: if not data_file :
st.error("请上传所有必需的文件") st.error("请上传所有必需的文件")
else: else:
# 模拟任务运行 # 模拟任务运行
st.success("任务运行中...") st.success("任务运行中...")
# 获取选中的模型
select_models = st.session_state['models']
# 获取边长
# 生成配置 run_task(data_file, select_models, edges_length)
config = {
"edges_length": edges_length,
"output_dir": "../result",
"select_models": [{"model_name": model} for model in st.session_state['models']]
}
st.json(config)
# TODO: 这里可以加入你运行任务的代码,例如调用模型预测函数等
# run_task(data_file, model_file, gnn_model_file, config)
st.success("任务完成") st.success("任务完成")
with tabs[1]: with tabs[1]:
st.header("新增模型") st.header("新增模型")
@ -77,7 +130,7 @@ with tabs[1]:
if new_model_file is None: if new_model_file is None:
st.error("请上传模型文件") st.error("请上传模型文件")
else: else:
model_dir = "model" model_dir = "../egnn_v2/model"
os.makedirs(model_dir, exist_ok=True) os.makedirs(model_dir, exist_ok=True)
model_file_path = os.path.join(model_dir, new_model_file.name) model_file_path = os.path.join(model_dir, new_model_file.name)
@ -86,7 +139,6 @@ with tabs[1]:
f.write(new_model_file.getbuffer()) f.write(new_model_file.getbuffer())
st.success(f"模型文件已保存到 {model_file_path}") st.success(f"模型文件已保存到 {model_file_path}")
st.experimental_rerun() # 刷新页面以更新模型选项 st.rerun() # 刷新页面以更新模型选项
# 运行 Streamlit 应用
# 保存为 app.py 后运行 `streamlit run app.py`