atom-predict/msunet/app.py

251 lines
9.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import argparse
import os
import zipfile
import glob
from typing import List
from pre import process_slide
from test_pl_unequal import predict
from typing import Dict
from typing import Tuple
from utils.labelme import save_pred_to_json
from plot import plot2
from predict import predict_and_plot
from train_pl import train, pre_process, create_save_path
import numpy as np
from dp.launching.report import Report, ReportSection, AutoReportElement
import sys
from pathlib import Path
import traceback
from dp.launching.typing import BaseModel, Field, Float, Optional
from dp.launching.typing import BaseModel, String, Set, Enum, List
from dp.launching.cli import to_runner, default_minimal_exception_handler
from dp.launching.typing.io import InputFilePath, OutputDirectory
from dp.launching.cli import (
SubParser,
default_minimal_exception_handler,
run_sp_and_exit,
to_runner,
)
# Put the sibling egnn_v2 directory on the import path. The imports that follow
# (predict_pl_v2_aug, train_pl_v2_aug, model_type_dict, plot_view) presumably
# live there, so this line must stay above them.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn_v2')))
# NOTE(review): debug output emitted on every import / CLI run — consider removing.
print(sys.path)
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
from train_pl_v2_aug import train as egnn_train
from model_type_dict import norm_line_sv_label, cz_label, xj_label, smov_label, model_path_dict
from plot_view import plot_multiple_annotations
# Options for the atom-defect prediction sub-command ("原子缺陷预测").
class PredictOptions(BaseModel):
    # Zip archive holding the dataset to run prediction on.
    data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
    # Optional detector checkpoint (.ckpt); predict_app falls back to the
    # bundled model/last.ckpt when this is None.
    model_path: Optional[InputFilePath] = Field(ftypes=['ckpt'], description="测试使用的模型")
    # Directory where prediction results and plots are written.
    output_dir: OutputDirectory = Field(
        default="./result"
    )  # the default is overridden when the app runs online
# Options for the detector training sub-command ("原子缺陷模型训练").
class TrainOptions(BaseModel):
    # Zip archive holding the training dataset.
    train_path: InputFilePath = Field(..., ftypes=['zip'], description="训练数据集")
    # Optional zip archive with a test dataset evaluated after training.
    # Its description used to read "测试使用的模型" ("model used for testing"),
    # which mislabelled this dataset upload in the UI; it now matches the
    # wording used for the other dataset fields.
    test_path: Optional[InputFilePath] = Field(ftypes=['zip'], description="测试的数据集")
    # Directory where the trained model and test results are written.
    output_dir: OutputDirectory = Field(
        default="./result"
    )
# Options for the two-stage (detector + GNN) prediction sub-command ("gnn预测").
class GnnPredictOptions(BaseModel):
    # Edge length used when connecting detected nodes into a graph; forwarded
    # to egnn_predict_and_plot. Units are not visible here (presumably pixels —
    # TODO confirm).
    edges_length: Float = Field(default=35.5, description="连接节点边长")
    # Zip archive holding the dataset to run prediction on.
    data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
    # Optional atom-defect detector checkpoint; None selects the bundled default.
    model_path: Optional[InputFilePath] = Field(ftypes=['ckpt'], description="原子缺陷检测模型")
    # Optional GNN checkpoint; None selects the bundled ../egnn_v2 default.
    gnn_model_path: Optional[InputFilePath] = Field(ftypes=['ckpt'], description="gnn模型")
    # Directory where results are written.
    output_dir: OutputDirectory = Field(
        default="./result"
    )  # the default is overridden when the app runs online
# Options for the GNN training sub-command ("gnn模型训练").
class GnnTrainOptions(BaseModel):
    # Edge length used when connecting nodes into a graph; forwarded to the
    # egnn training and prediction code. Units not visible here (presumably
    # pixels — TODO confirm).
    edges_length: Float = Field(default=35.5, description="连接节点边长")
    # Zip archive holding the GNN training dataset.
    train_path: InputFilePath = Field(..., ftypes=['zip'], description="训练数据集")
    # Optional zip archive with a test dataset evaluated after training.
    # Its description used to read "测试使用的模型" ("model used for testing"),
    # which mislabelled this dataset upload in the UI.
    test_path: Optional[InputFilePath] = Field(ftypes=['zip'], description="测试的数据集")
    # Directory where the trained model and test results are written.
    output_dir: OutputDirectory = Field(
        default="./result"
    )
# Closed set of GNN model labels selectable in the combined prediction UI.
# Each value is a label from model_type_dict and is used as a key into
# model_path_dict to locate the checkpoint.
class SetOptional(String, Enum):
    option_i_1 = norm_line_sv_label
    option_i_2 = cz_label
    option_i_3 = xj_label
    option_i_4 = smov_label
# One entry of CombineGnnPredictOptions.select_models.
class SingleWrapper(BaseModel):
    # Which GNN model to run.
    model_name: SetOptional
    # Per-model edge length is disabled; the shared
    # CombineGnnPredictOptions.edges_length is used for every model.
    #edges_length: Float = Field(description="连接节点边长")
# Options for running the detector once and then several GNN models on its
# output ("combine_gnn预测").
class CombineGnnPredictOptions(BaseModel):
    # GNN models to run (at least one, no duplicates).
    select_models: List[SingleWrapper] = Field(default=[SingleWrapper(model_name=norm_line_sv_label)],min_items=1, unique_items=True, description="选择的模型")
    # Edge length shared by all selected GNN models.
    edges_length: Float = Field(default=35.5, description="连接节点边长")
    # Zip archive holding the dataset to run prediction on.
    data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
    # The checkpoint-override fields below are disabled, so instances of this
    # model have no model_path / gnn_model_path attributes.
    #model_path: Optional[InputFilePath] = Field(ftypes=['ckpt'], description="原子缺陷检测模型")
    #gnn_model_path: Optional[InputFilePath] = Field(ftypes=['ckpt'], description="gnn模型")
    # Directory where results are written.
    output_dir: OutputDirectory = Field(
        default="./result"
    )  # the default is overridden when the app runs online
# Options model registered for the "原子缺陷预测" sub-command in to_parser.
class PredictGlobalOptions(PredictOptions, BaseModel):
    ...
# Options model registered for the "原子缺陷模型训练" sub-command in to_parser.
class TrainGlobalOptions(TrainOptions, BaseModel):
    ...
# Options model registered for the "gnn预测" sub-command in to_parser.
class GnnPredictGlobalOptions(GnnPredictOptions, BaseModel):
    ...
# Options model registered for the "gnn模型训练" sub-command in to_parser.
class GnnTrainGlobalOptions(GnnTrainOptions, BaseModel):
    ...
# Options model registered for the "combine_gnn预测" sub-command in to_parser.
class CombineGnnPredictGlobalOptions(CombineGnnPredictOptions, BaseModel):
    ...
def predict_app(opts: PredictGlobalOptions) -> int:
    """Run the atom-defect detector on a zipped dataset and plot the results.

    Args:
        opts: Parsed options. When ``opts.model_path`` is None the bundled
            ``model/last.ckpt`` next to this file is used.

    Returns:
        0 on success.
    """
    if opts.model_path is None:
        print("未指定模型路径,使用默认模型")
        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model/last.ckpt")
    else:
        # Resolve the uploaded checkpoint like every other InputFilePath in this
        # file (data_path, train_path, ...); previously the raw InputFilePath
        # object was handed to predict_and_plot. The resolved path is kept local
        # instead of being written back into the options model.
        model_path = opts.model_path.get_full_path()
    predict_and_plot(model_path, opts.data_path.get_full_path(), opts.output_dir.get_full_path())
    return 0
def gnn_predict_app(opts: GnnPredictGlobalOptions) -> int:
    """Run the detector, then the GNN model on the detector's post-processed output.

    Args:
        opts: Parsed options. ``gnn_model_path`` / ``model_path`` fall back to the
            bundled ``../egnn_v2/model/last.ckpt`` / ``model/last.ckpt`` when None.

    Returns:
        0 on success.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    if opts.gnn_model_path is None:
        print("未指定gnn模型路径使用默认模型")
        gnn_model_path = os.path.join(base_dir, "../egnn_v2/model/last.ckpt")
    else:
        # Resolve uploads like the other InputFilePath fields instead of passing
        # the raw InputFilePath object on.
        gnn_model_path = opts.gnn_model_path.get_full_path()
    if opts.model_path is None:
        print("未指定模型路径,使用默认模型")
        model_path = os.path.join(base_dir, "model/last.ckpt")
    else:
        model_path = opts.model_path.get_full_path()

    output_dir = opts.output_dir.get_full_path()
    save_path = predict_and_plot(model_path, opts.data_path.get_full_path(), output_dir)
    # The GNN consumes the detector's post-processed results.
    test_path = save_path["post_process"]
    print("gnn 预测开始")
    egnn_predict_and_plot(gnn_model_path, test_path, output_dir, opts.edges_length)
    # NOTE(review): combine_gnn_app passes gnn_generate_report the value returned
    # by egnn_predict_and_plot, while this path passes the detector's save_path
    # and discards that return value — confirm which dict the report expects.
    gnn_generate_report(save_path, output_dir)
    print("gnn 预测结束")
    return 0
def train_app(opts: TrainGlobalOptions) -> int:
    """Train the atom-defect detector and optionally evaluate it on a test set.

    The training zip is extracted into the directory layout produced by
    ``create_save_path``, pre-processed, and trained on. When ``opts.test_path``
    is given, the resulting model is run on it via ``predict_and_plot``.

    Returns:
        0 on success.
    """
    data_path = opts.train_path.get_full_path()
    base_path = opts.output_dir.get_full_path()
    # Create the output directory layout.
    save_path = create_save_path(base_path)
    print("save path: ", save_path)
    original_path = save_path['train_original']
    train_pre_process_path = save_path['train_pre_process']
    model_save_path = save_path['model_save_path']
    # Extract the training archive into original_path.
    with zipfile.ZipFile(data_path, 'r') as zip_ref:
        zip_ref.extractall(original_path)
    pre_process(original_path, train_pre_process_path)
    print("train start")
    model_path = train(train_pre_process_path, model_save_path)
    # "test start" is printed only when testing actually runs (it used to be
    # printed unconditionally before this check as well).
    if opts.test_path is not None:
        print("test start")
        test_path = opts.test_path.get_full_path()
        predict_and_plot(model_path, test_path, base_path)
        print("test end")
    else:
        print("未指定测试数据集,跳过测试")
    return 0
def gnn_train_app(opts: GnnTrainGlobalOptions) -> int:
    """Train the GNN model, then optionally evaluate it on a test dataset.

    The training zip is extracted to ``<output>/train/dataset`` and the model is
    saved under ``<output>/train/model``. When ``opts.test_path`` is given, the
    bundled detector (``model/last.ckpt``) is run on it first and the freshly
    trained GNN is applied to the detector's post-processed output.

    Returns:
        0 on success.
    """
    train_path = opts.train_path.get_full_path()
    base_path = opts.output_dir.get_full_path()
    model_save_path = os.path.join(base_path, "train/model")
    train_data_path = os.path.join(base_path, "train/dataset")
    with zipfile.ZipFile(train_path, 'r') as zip_ref:
        zip_ref.extractall(train_data_path)
    print("gnn 训练开始")
    # Presumably the trained checkpoint path: it is passed to egnn_predict_and_plot
    # in the model-path position below.
    gnn_model_path = egnn_train(train_data_path, model_save_path, opts.edges_length)
    print("gnn 训练结束")

    # test_path is Optional; previously .get_full_path() was called on it
    # unconditionally, crashing with AttributeError when no test set was given.
    if opts.test_path is None:
        print("未指定测试数据集,跳过测试")
        return 0

    # The detector stage always uses the bundled checkpoint.
    model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model/last.ckpt")
    save_path = predict_and_plot(model_path, opts.test_path.get_full_path(), base_path)
    test_path = save_path["post_process"]
    print("gnn 预测开始")
    egnn_predict_and_plot(gnn_model_path, test_path, base_path, opts.edges_length)
    gnn_generate_report(save_path, base_path)
    print("gnn 预测结束")
    return 0
def combine_gnn_app(opts: CombineGnnPredictGlobalOptions) -> int:
    """Run the detector once, then every selected GNN model on its output.

    Each selected model label is mapped to a checkpoint under ``../egnn_v2/``
    via ``model_path_dict``; each GNN run produces its own report.

    Returns:
        0 on success, 1 if any step raised (the traceback is printed).
    """
    try:
        print("select_models的类型是", type(opts.select_models))
        print("选择的模型:", opts.select_models)
        for model in opts.select_models:
            print("选择的模型的名称:", model.model_name.value)

        base_dir = os.path.dirname(os.path.abspath(__file__))
        # CombineGnnPredictOptions declares no model_path / gnn_model_path
        # fields. The previous code assigned and read them on the pydantic
        # model, which raises, so this command always failed. Read the
        # override defensively and never write to the options model.
        user_model = getattr(opts, "model_path", None)
        if user_model is None:
            print("未指定模型路径,使用默认模型")
            model_path = os.path.join(base_dir, "model/last.ckpt")
        else:
            model_path = user_model.get_full_path()

        output_dir = opts.output_dir.get_full_path()
        save_path = predict_and_plot(model_path, opts.data_path.get_full_path(), output_dir)
        # All GNN models consume the detector's post-processed results.
        test_path = save_path["post_process"]
        print("gnn 预测开始")
        for opt in opts.select_models:
            model_name = opt.model_name.value
            gnn_model_path = os.path.join(base_dir, "../egnn_v2/", model_path_dict[model_name])
            print("gnn model path :", gnn_model_path)
            gnn_save_path = egnn_predict_and_plot(gnn_model_path, test_path, output_dir, opts.edges_length, model_name)
            gnn_generate_report(gnn_save_path, output_dir)
        print("gnn 预测结束")
    except Exception as e:
        # Top-level CLI boundary: report the failure and signal it via the return code.
        print(f"Error processing error: {e}")
        traceback.print_exc()
        return 1
    return 0
def to_parser():
    """Build the sub-command table for the CLI.

    Returns:
        A dict mapping each sub-command name to a ``SubParser`` built from its
        options model, handler function, and short description string, in
        display order.
    """
    commands = (
        ("原子缺陷预测", PredictGlobalOptions, predict_app, "预测"),
        ("原子缺陷模型训练", TrainGlobalOptions, train_app, "训练"),
        ("gnn预测", GnnPredictGlobalOptions, gnn_predict_app, "gnn预测"),
        ("combine_gnn预测", CombineGnnPredictGlobalOptions, combine_gnn_app, "gnn预测"),
        ("gnn模型训练", GnnTrainGlobalOptions, gnn_train_app, "训练"),
    )
    return {
        name: SubParser(options_model, handler, label)
        for name, options_model, handler, label in commands
    }
if __name__ == "__main__":
    # Script entry point: dispatch to the sub-command chosen on the command line.
    sub_parsers = to_parser()
    run_sp_and_exit(
        sub_parsers,
        description="原子缺陷检测工具",
        version="0.1.0",
        exception_handler=default_minimal_exception_handler,
    )