"""Streamlit front-end for combined defect prediction + GNN (EGNN) prediction.

Tab 1 ("combine_gnn预测"):
- Takes an uploaded zip dataset.
- Runs the base ``predict_and_plot`` model (``model/last.ckpt``).
- Runs every selected GNN checkpoint on the post-processed output.
- Displays the resulting images.

Tab 2 ("新增模型") saves an uploaded ``.ckpt`` file into the GNN model directory.
"""
import argparse
import glob
import os
import shutil
import sys
import zipfile
from pathlib import Path
from typing import Dict, List, Tuple

import streamlit as st
from pre import process_slide
from test_pl_unequal import predict
from utils.labelme import save_pred_to_json
from plot import plot2
from predict import predict_and_plot
from train_pl import train, pre_process, create_save_path
import numpy as np
from dp.launching.report import Report, ReportSection, AutoReportElement

# Directory containing this script. Filesystem locations are resolved against
# it so that listing, loading and saving models all agree no matter which
# working directory the app was launched from.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Make the sibling `egnn_v2` package importable before importing from it.
sys.path.append(os.path.abspath(os.path.join(BASE_DIR, '../egnn_v2')))
print(sys.path)
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report  # noqa: E402
from train_pl_v2_aug import train as egnn_train  # noqa: E402

# Directory holding the selectable GNN checkpoints (`*.ckpt`).
GNN_MODEL_DIR = os.path.abspath(os.path.join(BASE_DIR, "../egnn_v2/model"))
# Checkpoint of the base (non-GNN) prediction model.
BASE_MODEL_PATH = os.path.join(BASE_DIR, "model/last.ckpt")
# File extensions treated as displayable result images.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def get_model_options(model_dir=None):
    """Return the names (file stem, without ``.ckpt``) of the checkpoints in *model_dir*.

    Args:
        model_dir: Directory to scan. Defaults to ``GNN_MODEL_DIR``, the same
            directory that ``run_task`` loads checkpoints from and the
            "新增模型" tab saves into.

    Returns:
        Sorted list of checkpoint names. The list is empty if the directory is
        missing or contains no ``*.ckpt`` files.
    """
    if model_dir is None:
        model_dir = GNN_MODEL_DIR
    ckpt_files = glob.glob(os.path.join(model_dir, "*.ckpt"))
    # splitext strips only the final extension, unlike str.replace, which
    # would also remove ".ckpt" appearing elsewhere in the name.
    return sorted(os.path.splitext(os.path.basename(f))[0] for f in ckpt_files)


def load_images_from_folder(folder_path):
    """List the image files directly inside *folder_path*.

    Returns:
        A list of ``(filename, image_path)`` tuples, sorted by filename.
        ``st.image`` accepts a file path, so no image library is needed to
        display them. Returns an empty list if *folder_path* is falsy or is
        not an existing directory.
    """
    if not folder_path or not os.path.isdir(folder_path):
        return []
    return [
        (filename, os.path.join(folder_path, filename))
        for filename in sorted(os.listdir(folder_path))
        if filename.lower().endswith(IMAGE_EXTENSIONS)
    ]


def run_task(data_file, select_models, edges_length):
    """Run the base prediction, then each selected GNN model, on an uploaded zip.

    Args:
        data_file: Streamlit ``UploadedFile`` holding the zip dataset.
        select_models: Checkpoint names (stems in ``GNN_MODEL_DIR``) to run.
        edges_length: Value forwarded unchanged to ``egnn_predict_and_plot``.

    Returns:
        A dict with the keys ``"result_view"`` and ``"connect_view"``. Each
        maps to a sub-directory of the output directory. NOTE(review): these
        directories are presumably written by ``egnn_predict_and_plot``; this
        is not verifiable from this file.
    """
    # Debug output shown in the page.
    st.write(f"Edges Length: {edges_length}")
    st.write(f"Selected Models: {select_models}")
    st.write(f"data_file: {data_file.name}")

    # Save the uploaded zip into `<script dir>/temp`. The upload name is
    # client-controlled, so keep only its final component to prevent writing
    # outside temp_dir.
    temp_dir = os.path.join(BASE_DIR, "temp")
    os.makedirs(temp_dir, exist_ok=True)
    temp_file_path = os.path.join(temp_dir, os.path.basename(data_file.name))
    with open(temp_file_path, "wb") as f:
        f.write(data_file.getbuffer())

    print("select_models的类型是:", type(select_models))
    print("选择的模型:", select_models)
    for model in select_models:
        print("选择的模型的名称:", model)

    # The base prediction does not depend on the selected GNN models, so run
    # it exactly once.
    output_dir = "./streamlit_app_result"
    save_path = predict_and_plot(BASE_MODEL_PATH, temp_file_path, output_dir)
    test_path = save_path["post_process"]

    print("gnn 预测开始")
    for opt in select_models:
        gnn_model_path = os.path.join(GNN_MODEL_DIR, opt + ".ckpt")
        print("gnn model path :", gnn_model_path)
        egnn_predict_and_plot(gnn_model_path, test_path, output_dir, edges_length)
    print("gnn 预测结束")

    return {
        "result_view": os.path.join(output_dir, "gnn_predict_result_view"),
        "connect_view": os.path.join(output_dir, "gnn_predict_connect_view"),
    }


# Session state initialisation. Streamlit re-executes this whole script on
# every interaction, so anything that must survive a rerun lives here rather
# than in plain module-level variables.
if 'models' not in st.session_state:
    st.session_state['models'] = []
if 'result_paths' not in st.session_state:
    # Result directories of the last completed run (None until a run finishes).
    st.session_state['result_paths'] = {
        "result_view": None,
        "connect_view": None,
    }

# Page title
st.title("Bohrium Apps 开发者中心")

# Create the tabs
tabs = st.tabs(["combine_gnn预测", "新增模型"])

with tabs[0]:
    data_file = st.file_uploader("Data Path", type=["zip"], help="测试的缺陷数据集")

    model_options = get_model_options()
    # st.multiselect raises if a default is not among the options (e.g. a
    # checkpoint was deleted since the last run), so drop stale selections.
    default_models = [m for m in st.session_state['models'] if m in model_options]
    selected_models = st.multiselect("选择模型", model_options, default=default_models)
    st.session_state['models'] = selected_models

    edges_length = st.number_input("Edges Length", value=35.5)

    if st.button("运行任务"):
        if not data_file:
            st.error("请上传所有必需的文件")
        elif not selected_models:
            # Without a GNN model no result images would be produced.
            st.error("请至少选择一个模型")
        else:
            with st.spinner("任务运行中..."):
                st.session_state['result_paths'] = run_task(
                    data_file, st.session_state['models'], edges_length
                )
            st.success("任务完成")

    # Show the results of the last completed run.
    result_paths = st.session_state['result_paths']
    if result_paths["result_view"]:
        st.subheader("GNN 预测结果")
        for filename, img in load_images_from_folder(result_paths["result_view"]):
            st.image(img, caption=filename, use_column_width=True)
    if result_paths["connect_view"]:
        st.subheader("连接视图")
        for filename, img in load_images_from_folder(result_paths["connect_view"]):
            st.image(img, caption=filename, use_column_width=True)

with tabs[1]:
    st.header("新增模型")

    # Show the confirmation queued before the previous st.rerun(). Displaying
    # it directly would wipe it immediately, because the rerun clears it.
    saved_msg = st.session_state.pop('model_saved_msg', None)
    if saved_msg:
        st.success(saved_msg)

    new_model_file = st.file_uploader("Model Path", type=["ckpt"], help="上传新的模型文件 (.ckpt)")
    if st.button("提交"):
        if new_model_file is None:
            st.error("请上传模型文件")
        else:
            # Save into the same directory get_model_options() scans. The
            # client-supplied name is reduced to its final component.
            os.makedirs(GNN_MODEL_DIR, exist_ok=True)
            model_file_path = os.path.join(GNN_MODEL_DIR, os.path.basename(new_model_file.name))
            with open(model_file_path, "wb") as f:
                f.write(new_model_file.getbuffer())
            st.session_state['model_saved_msg'] = f"模型文件已保存到 {model_file_path}"
            st.rerun()  # Refresh the page so the new model appears in the options