2024-07-24 16:10:08 +08:00
|
|
|
|
import streamlit as st
|
|
|
|
|
import os
|
|
|
|
|
import glob
|
|
|
|
|
import shutil
|
2024-07-24 17:28:37 +08:00
|
|
|
|
import argparse
|
|
|
|
|
import os
|
|
|
|
|
import zipfile
|
|
|
|
|
import glob
|
|
|
|
|
from typing import List
|
|
|
|
|
from pre import process_slide
|
|
|
|
|
from test_pl_unequal import predict
|
|
|
|
|
from typing import Dict
|
|
|
|
|
from typing import Tuple
|
|
|
|
|
from utils.labelme import save_pred_to_json
|
|
|
|
|
from plot import plot2
|
|
|
|
|
from predict import predict_and_plot
|
|
|
|
|
from train_pl import train, pre_process, create_save_path
|
|
|
|
|
import numpy as np
|
|
|
|
|
from dp.launching.report import Report, ReportSection, AutoReportElement
|
|
|
|
|
import sys
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn_v2')))
|
|
|
|
|
print(sys.path)
|
|
|
|
|
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
|
|
|
|
|
from train_pl_v2_aug import train as egnn_train
|
2024-07-24 16:10:08 +08:00
|
|
|
|
|
|
|
|
|
# List the `*.ckpt` checkpoint files found in the model folder (default: `../egnn_v2/model`).
|
2024-07-24 17:28:37 +08:00
|
|
|
|
def get_model_options(model_dir="../egnn_v2/model"):
    """Return the names of the checkpoints stored in ``model_dir``.

    Args:
        model_dir: Folder scanned (non-recursively) for ``*.ckpt`` files.
            A relative path is interpreted by ``glob`` against the current
            working directory.

    Returns:
        list[str]: Checkpoint file names without the trailing ``.ckpt``
        extension, sorted alphabetically. Empty when nothing matches.
    """
    ckpt_files = glob.glob(os.path.join(model_dir, "*.ckpt"))
    # splitext removes only the final extension. A plain str.replace would
    # also eat a ".ckpt" in the middle of the name (e.g. "a.ckpt.v2.ckpt"),
    # producing an option that no longer maps back to a file on disk.
    # Sorting keeps the option order stable, since glob order is arbitrary.
    return sorted(os.path.splitext(os.path.basename(path))[0] for path in ckpt_files)
|
|
|
|
|
|
2024-07-25 10:15:32 +08:00
|
|
|
|
def load_images_from_folder(folder_path):
    """Open every PNG/JPEG image located directly inside ``folder_path``.

    Args:
        folder_path: Directory to scan; sub-directories are not visited.

    Returns:
        list[tuple[str, PIL.Image.Image]]: ``(filename, image)`` pairs ordered
        by filename. Empty when the folder holds no matching file.

    Raises:
        OSError: Propagated from ``os.listdir`` when the folder cannot be
            listed (e.g. it does not exist).
    """
    image_suffixes = ('.png', '.jpg', '.jpeg')
    # os.listdir returns entries in arbitrary order, so sort for a stable
    # display order. Compare lower-cased names so ".PNG" / ".JPG" files are
    # picked up as well.
    filenames = sorted(
        name for name in os.listdir(folder_path)
        if name.lower().endswith(image_suffixes)
    )
    if not filenames:
        return []

    # `Image` is not imported anywhere in this file, so opening the first
    # image used to raise NameError. Pillow is a dependency of Streamlit;
    # import it here, only once there is something to open.
    from PIL import Image

    return [
        (name, Image.open(os.path.join(folder_path, name)))
        for name in filenames
    ]
|
|
|
|
|
|
2024-07-24 16:10:08 +08:00
|
|
|
|
|
|
|
|
|
# Initialisation: make sure the session carries a list of selected model names.
st.session_state.setdefault('models', [])

# Page title.
st.title("Bohrium Apps 开发者中心")

# Create the two tabs: GNN prediction and model upload.
tabs = st.tabs(["combine_gnn预测", "新增模型"])
|
|
|
|
|
|
2024-07-24 17:28:37 +08:00
|
|
|
|
def run_task(data_file, select_models, edges_length):
    """Run the base prediction on an uploaded archive, then each selected GNN model.

    Args:
        data_file: Uploaded zip archive; must expose ``name`` and
            ``getbuffer()`` (as the object returned by ``st.file_uploader``).
        select_models: Iterable of GNN checkpoint names without the ``.ckpt``
            extension; each is looked up in ``../egnn_v2/model`` relative to
            this script's directory.
        edges_length: Passed through unchanged to ``egnn_predict_and_plot``.

    Returns:
        dict: ``"result_view"`` and ``"connect_view"`` mapped to the
        ``gnn_predict_result_view`` / ``gnn_predict_connect_view`` folders
        under the output directory.
    """
    # Echo the inputs on the page for debugging.
    st.write(f"Edges Length: {edges_length}")
    st.write(f"Selected Models: {select_models}")
    st.write(f"data_file: {data_file.name}")

    # Store the zip in a `temp` folder next to this script (absolute path).
    base_dir = os.path.dirname(os.path.abspath(__file__))
    temp_dir = os.path.join(base_dir, "temp")
    os.makedirs(temp_dir, exist_ok=True)

    # The upload name is supplied by the client. Keep only its last path
    # component so a crafted name cannot write outside `temp_dir`.
    safe_name = os.path.basename(data_file.name.replace("\\", "/"))
    temp_file_path = os.path.join(temp_dir, safe_name)
    with open(temp_file_path, "wb") as f:
        f.write(data_file.getbuffer())

    print("select_models的类型是:", type(select_models))
    print("选择的模型:", select_models)
    for model in select_models:
        print("选择的模型的名称:", model)

    output_dir = "./streamlit_app_result"
    model_path = os.path.join(base_dir, "model/last.ckpt")
    save_path = predict_and_plot(model_path, temp_file_path, output_dir)

    # The GNN models consume the "post_process" entry of the first stage's result.
    test_path = save_path["post_process"]
    print("gnn 预测开始")

    gnn_model_dir = os.path.join(base_dir, "../egnn_v2/model")
    for opt in select_models:
        gnn_model_path = os.path.join(gnn_model_dir, opt + ".ckpt")
        print("gnn model path :", gnn_model_path)
        egnn_predict_and_plot(gnn_model_path, test_path, output_dir, edges_length)

    print("gnn 预测结束")

    return {
        "result_view": os.path.join(output_dir, "gnn_predict_result_view"),
        "connect_view": os.path.join(output_dir, "gnn_predict_connect_view"),
    }
|
|
|
|
|
|
|
|
|
|
# Folders whose images the prediction tab displays; None means nothing to show.
result_paths = dict.fromkeys(("result_view", "connect_view"))
|
|
|
|
|
|
2024-07-24 17:28:37 +08:00
|
|
|
|
|
2024-07-24 16:10:08 +08:00
|
|
|
|
with tabs[0]:
    # Zip archive containing the defect dataset to run the prediction on.
    data_file = st.file_uploader("Data Path", type=["zip"], help="测试的缺陷数据集")

    # Offer every checkpoint currently available.
    model_options = get_model_options()

    # Multi-select, pre-filled with the previous choice. Defaults that are no
    # longer among the options are dropped, because st.multiselect raises
    # when a default value is not one of the options.
    remembered_models = [m for m in st.session_state['models'] if m in model_options]
    selected_models = st.multiselect("选择模型", model_options, default=remembered_models)
    st.session_state['models'] = selected_models

    # Edge length forwarded to the GNN prediction.
    edges_length = st.number_input("Edges Length", value=35.5)

    # "Run task" button.
    if st.button("运行任务"):
        if not data_file:
            st.error("请上传所有必需的文件")
        else:
            st.success("任务运行中...")
            # Keep the returned folders. The return value used to be
            # discarded, so `result_paths` stayed None and nothing below
            # was ever displayed.
            result_paths = run_task(data_file, selected_models, edges_length)
            st.success("任务完成")

    # Show the result images, if any were produced.
    for view_key, heading in (("result_view", "GNN 预测结果"), ("connect_view", "连接视图")):
        view_dir = result_paths[view_key]
        # Skip folders that were not created, e.g. when no GNN model was selected.
        if view_dir and os.path.isdir(view_dir):
            st.subheader(heading)
            for filename, img in load_images_from_folder(view_dir):
                st.image(img, caption=filename, use_column_width=True)
|
2024-07-24 17:28:37 +08:00
|
|
|
|
|
2024-07-24 16:10:08 +08:00
|
|
|
|
with tabs[1]:
    st.header("新增模型")

    # Upload a new checkpoint file.
    new_model_file = st.file_uploader("Model Path", type=["ckpt"], help="上传新的模型文件 (.ckpt)")

    if st.button("提交"):
        if new_model_file is None:
            st.error("请上传模型文件")
        else:
            # NOTE(review): this path is relative to the current working
            # directory, whereas run_task builds the checkpoint path from
            # this script's location; confirm the app is always launched
            # from the script directory.
            model_dir = "../egnn_v2/model"
            os.makedirs(model_dir, exist_ok=True)

            # The upload name is supplied by the client. Keep only its last
            # path component so it cannot point outside `model_dir`.
            safe_name = os.path.basename(new_model_file.name.replace("\\", "/"))
            model_file_path = os.path.join(model_dir, safe_name)

            # Save the uploaded checkpoint.
            with open(model_file_path, "wb") as f:
                f.write(new_model_file.getbuffer())

            st.success(f"模型文件已保存到 {model_file_path}")
            # Rerun the script so the model list picks up the new file.
            # NOTE(review): the rerun happens immediately, so the success
            # message above may not remain visible.
            st.rerun()
|
|
|
|
|
|
2024-07-24 16:10:08 +08:00
|
|
|
|
|