streamlit_app动态更新和新增模型

This commit is contained in:
somunslotus 2024-07-24 17:28:37 +08:00
parent e871c3b8ed
commit c4a45a925f
1 changed files with 73 additions and 21 deletions

View File

@ -2,10 +2,32 @@ import streamlit as st
import os
import glob
import shutil
import argparse
import os
import zipfile
import glob
from typing import List
from pre import process_slide
from test_pl_unequal import predict
from typing import Dict
from typing import Tuple
from utils.labelme import save_pred_to_json
from plot import plot2
from predict import predict_and_plot
from train_pl import train, pre_process, create_save_path
import numpy as np
from dp.launching.report import Report, ReportSection, AutoReportElement
import sys
from pathlib import Path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn_v2')))
print(sys.path)
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
from train_pl_v2_aug import train as egnn_train
# 获取运行脚本目录中的 `model` 文件夹内的所有 `*.ckpt` 文件
def get_model_options(model_dir="model"):
def get_model_options(model_dir="../egnn_v2/model"):
ckpt_files = glob.glob(os.path.join(model_dir, "*.ckpt"))
model_options = [os.path.basename(f).replace(".ckpt", "") for f in ckpt_files]
return model_options
@ -21,20 +43,57 @@ st.title("Bohrium Apps 开发者中心")
# 创建 Tabs
tabs = st.tabs(["combine_gnn预测", "新增模型"])
def run_task(data_file,select_models, edges_length):
# 打印调试信息
st.write(f"Edges Length: {edges_length}")
st.write(f"Selected Models: {select_models}")
st.write(f"data_file: {data_file.name}")
# 保存zip文件到临时目录, 临时目录为当前目录下的 `temp` 文件夹,使用绝对路径
base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)))
temp_dir = os.path.join(base_dir, "temp")
os.makedirs(temp_dir, exist_ok=True)
temp_file_path = os.path.join(temp_dir, data_file.name)
with open(temp_file_path, "wb") as f:
f.write(data_file.getbuffer())
print("select_models的类型是", type(select_models))
print("选择的模型:",select_models)
for model in select_models:
print("选择的模型的名称:",model)
output_dir = "./streamlit_app_result"
model_path = os.path.join(base_dir, "model/last.ckpt")
save_path = predict_and_plot(model_path, temp_file_path, output_dir)
test_path = save_path["post_process"]
print("gnn 预测开始")
for opt in select_models:
gnn_model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../egnn_v2/model", opt+".ckpt")
print("gnn model path :", gnn_model_path)
egnn_predict_and_plot(gnn_model_path, test_path, output_dir, edges_length)
print("gnn 预测结束")
with tabs[0]:
st.header("combine_gnn预测")
# st.header("combine_gnn预测")
# Data Path
data_file = st.file_uploader("Data Path", type=["zip"], help="测试的缺陷数据集")
# Model Path
model_file = st.file_uploader("Model Path", type=["ckpt"], help="原子缺陷处理模型")
#model_file = st.file_uploader("Model Path", type=["ckpt"], help="原子缺陷处理模型")
# Gnn Model Path
gnn_model_file = st.file_uploader("Gnn Model Path", type=["ckpt"], help="gnn模型")
#gnn_model_file = st.file_uploader("Gnn Model Path", type=["ckpt"], help="gnn模型")
# Select Models
st.subheader("Select Models")
# st.subheader("Select Models")
# 获取模型选项
model_options = get_model_options()
@ -48,25 +107,19 @@ with tabs[0]:
# 运行任务按钮
if st.button("运行任务"):
if not data_file or not model_file or not gnn_model_file:
if not data_file :
st.error("请上传所有必需的文件")
else:
# 模拟任务运行
st.success("任务运行中...")
# 获取选中的模型
select_models = st.session_state['models']
# 获取边长
# 生成配置
config = {
"edges_length": edges_length,
"output_dir": "../result",
"select_models": [{"model_name": model} for model in st.session_state['models']]
}
st.json(config)
# TODO: 这里可以加入你运行任务的代码,例如调用模型预测函数等
# run_task(data_file, model_file, gnn_model_file, config)
run_task(data_file, select_models, edges_length)
st.success("任务完成")
with tabs[1]:
st.header("新增模型")
@ -77,7 +130,7 @@ with tabs[1]:
if new_model_file is None:
st.error("请上传模型文件")
else:
model_dir = "model"
model_dir = "../egnn_v2/model"
os.makedirs(model_dir, exist_ok=True)
model_file_path = os.path.join(model_dir, new_model_file.name)
@ -86,7 +139,6 @@ with tabs[1]:
f.write(new_model_file.getbuffer())
st.success(f"模型文件已保存到 {model_file_path}")
st.experimental_rerun() # 刷新页面以更新模型选项
st.rerun() # 刷新页面以更新模型选项
# 运行 Streamlit 应用
# 保存为 app.py 后运行 `streamlit run app.py`