streamlit_app动态更新和新增模型
This commit is contained in:
parent
e871c3b8ed
commit
c4a45a925f
|
@ -2,10 +2,32 @@ import streamlit as st
|
|||
import os
|
||||
import glob
|
||||
import shutil
|
||||
import argparse
|
||||
import os
|
||||
import zipfile
|
||||
import glob
|
||||
from typing import List
|
||||
from pre import process_slide
|
||||
from test_pl_unequal import predict
|
||||
from typing import Dict
|
||||
from typing import Tuple
|
||||
from utils.labelme import save_pred_to_json
|
||||
from plot import plot2
|
||||
from predict import predict_and_plot
|
||||
from train_pl import train, pre_process, create_save_path
|
||||
import numpy as np
|
||||
from dp.launching.report import Report, ReportSection, AutoReportElement
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn_v2')))
|
||||
print(sys.path)
|
||||
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
|
||||
from train_pl_v2_aug import train as egnn_train
|
||||
|
||||
# 获取运行脚本目录中的 `model` 文件夹内的所有 `*.ckpt` 文件
|
||||
def get_model_options(model_dir="model"):
|
||||
def get_model_options(model_dir="../egnn_v2/model"):
|
||||
ckpt_files = glob.glob(os.path.join(model_dir, "*.ckpt"))
|
||||
model_options = [os.path.basename(f).replace(".ckpt", "") for f in ckpt_files]
|
||||
return model_options
|
||||
|
@ -21,20 +43,57 @@ st.title("Bohrium Apps 开发者中心")
|
|||
# 创建 Tabs
|
||||
tabs = st.tabs(["combine_gnn预测", "新增模型"])
|
||||
|
||||
def run_task(data_file,select_models, edges_length):
|
||||
# 打印调试信息
|
||||
st.write(f"Edges Length: {edges_length}")
|
||||
st.write(f"Selected Models: {select_models}")
|
||||
st.write(f"data_file: {data_file.name}")
|
||||
# 保存zip文件到临时目录, 临时目录为当前目录下的 `temp` 文件夹,使用绝对路径
|
||||
base_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
temp_dir = os.path.join(base_dir, "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
temp_file_path = os.path.join(temp_dir, data_file.name)
|
||||
|
||||
|
||||
with open(temp_file_path, "wb") as f:
|
||||
f.write(data_file.getbuffer())
|
||||
|
||||
print("select_models的类型是:", type(select_models))
|
||||
print("选择的模型:",select_models)
|
||||
for model in select_models:
|
||||
print("选择的模型的名称:",model)
|
||||
|
||||
|
||||
output_dir = "./streamlit_app_result"
|
||||
model_path = os.path.join(base_dir, "model/last.ckpt")
|
||||
save_path = predict_and_plot(model_path, temp_file_path, output_dir)
|
||||
|
||||
test_path = save_path["post_process"]
|
||||
print("gnn 预测开始")
|
||||
|
||||
for opt in select_models:
|
||||
gnn_model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../egnn_v2/model", opt+".ckpt")
|
||||
print("gnn model path :", gnn_model_path)
|
||||
egnn_predict_and_plot(gnn_model_path, test_path, output_dir, edges_length)
|
||||
|
||||
print("gnn 预测结束")
|
||||
|
||||
|
||||
with tabs[0]:
|
||||
st.header("combine_gnn预测")
|
||||
# st.header("combine_gnn预测")
|
||||
|
||||
# Data Path
|
||||
data_file = st.file_uploader("Data Path", type=["zip"], help="测试的缺陷数据集")
|
||||
|
||||
# Model Path
|
||||
model_file = st.file_uploader("Model Path", type=["ckpt"], help="原子缺陷处理模型")
|
||||
#model_file = st.file_uploader("Model Path", type=["ckpt"], help="原子缺陷处理模型")
|
||||
|
||||
# Gnn Model Path
|
||||
gnn_model_file = st.file_uploader("Gnn Model Path", type=["ckpt"], help="gnn模型")
|
||||
#gnn_model_file = st.file_uploader("Gnn Model Path", type=["ckpt"], help="gnn模型")
|
||||
|
||||
# Select Models
|
||||
st.subheader("Select Models")
|
||||
# st.subheader("Select Models")
|
||||
|
||||
# 获取模型选项
|
||||
model_options = get_model_options()
|
||||
|
@ -48,25 +107,19 @@ with tabs[0]:
|
|||
|
||||
# 运行任务按钮
|
||||
if st.button("运行任务"):
|
||||
if not data_file or not model_file or not gnn_model_file:
|
||||
if not data_file :
|
||||
st.error("请上传所有必需的文件")
|
||||
else:
|
||||
# 模拟任务运行
|
||||
st.success("任务运行中...")
|
||||
# 获取选中的模型
|
||||
select_models = st.session_state['models']
|
||||
# 获取边长
|
||||
|
||||
# 生成配置
|
||||
config = {
|
||||
"edges_length": edges_length,
|
||||
"output_dir": "../result",
|
||||
"select_models": [{"model_name": model} for model in st.session_state['models']]
|
||||
}
|
||||
st.json(config)
|
||||
|
||||
# TODO: 这里可以加入你运行任务的代码,例如调用模型预测函数等
|
||||
# run_task(data_file, model_file, gnn_model_file, config)
|
||||
|
||||
run_task(data_file, select_models, edges_length)
|
||||
st.success("任务完成")
|
||||
|
||||
|
||||
with tabs[1]:
|
||||
st.header("新增模型")
|
||||
|
||||
|
@ -77,7 +130,7 @@ with tabs[1]:
|
|||
if new_model_file is None:
|
||||
st.error("请上传模型文件")
|
||||
else:
|
||||
model_dir = "model"
|
||||
model_dir = "../egnn_v2/model"
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
model_file_path = os.path.join(model_dir, new_model_file.name)
|
||||
|
||||
|
@ -86,7 +139,6 @@ with tabs[1]:
|
|||
f.write(new_model_file.getbuffer())
|
||||
|
||||
st.success(f"模型文件已保存到 {model_file_path}")
|
||||
st.experimental_rerun() # 刷新页面以更新模型选项
|
||||
st.rerun() # 刷新页面以更新模型选项
|
||||
|
||||
|
||||
# 运行 Streamlit 应用
|
||||
# 保存为 app.py 后运行 `streamlit run app.py`
|
||||
|
|
Loading…
Reference in New Issue