251 lines
9.7 KiB
Python
251 lines
9.7 KiB
Python
import argparse
|
||
import os
|
||
import zipfile
|
||
import glob
|
||
from typing import List
|
||
from pre import process_slide
|
||
from test_pl_unequal import predict
|
||
from typing import Dict
|
||
from typing import Tuple
|
||
from utils.labelme import save_pred_to_json
|
||
from plot import plot2
|
||
from predict import predict_and_plot
|
||
from train_pl import train, pre_process, create_save_path
|
||
import numpy as np
|
||
from dp.launching.report import Report, ReportSection, AutoReportElement
|
||
import sys
|
||
from pathlib import Path
|
||
import traceback
|
||
from dp.launching.typing import BaseModel, Field, Float, Optional
|
||
|
||
from dp.launching.typing import BaseModel, String, Set, Enum, List
|
||
from dp.launching.cli import to_runner, default_minimal_exception_handler
|
||
from dp.launching.typing.io import InputFilePath, OutputDirectory
|
||
|
||
from dp.launching.cli import (
|
||
SubParser,
|
||
default_minimal_exception_handler,
|
||
run_sp_and_exit,
|
||
to_runner,
|
||
)
|
||
# Directory of the sibling egnn_v2 sources. It is appended to the import path
# so the egnn modules imported right after this (predict_pl_v2_aug,
# train_pl_v2_aug, model_type_dict, plot_view) can be resolved.
_EGNN_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn_v2'))
sys.path.append(_EGNN_DIR)
# NOTE(review): debug dump of the import search path at import time.
print(sys.path)
|
||
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot, gnn_generate_report
|
||
from train_pl_v2_aug import train as egnn_train
|
||
from model_type_dict import norm_line_sv_label, cz_label, xj_label, smov_label, model_path_dict
|
||
from plot_view import plot_multiple_annotations
|
||
|
||
class PredictOptions(BaseModel):
    # Inputs for the atomic-defect prediction sub-command.
    # Zipped dataset to run prediction on.
    data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
    # Optional checkpoint. When omitted, predict_app falls back to the bundled
    # model/last.ckpt. The explicit default keeps the field optional under
    # both pydantic v1 and v2.
    model_path: Optional[InputFilePath] = Field(default=None, ftypes=['ckpt'], description="测试使用的模型")
    # Result directory; the default is overridden once deployed online.
    output_dir: OutputDirectory = Field(
        default="./result"
    )
|
||
|
||
class TrainOptions(BaseModel):
    # Inputs for the atomic-defect model training sub-command.
    # Zipped training dataset.
    train_path: InputFilePath = Field(..., ftypes=['zip'], description="训练数据集")
    # Optional zipped test dataset used to evaluate the trained model; train_app
    # skips evaluation when it is omitted. (The previous description wrongly
    # called this a model.)
    test_path: Optional[InputFilePath] = Field(default=None, ftypes=['zip'], description="测试数据集")
    # Root directory for training artifacts and evaluation output.
    output_dir: OutputDirectory = Field(
        default="./result"
    )
|
||
|
||
class GnnPredictOptions(BaseModel):
    # Inputs for the two-stage prediction (defect detector, then GNN).
    # Edge-length parameter forwarded to the GNN predictor.
    edges_length: Float = Field(default=35.5, description="连接节点边长")
    # Zipped dataset to run prediction on.
    data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
    # Optional detector checkpoint; gnn_predict_app falls back to model/last.ckpt.
    model_path: Optional[InputFilePath] = Field(default=None, ftypes=['ckpt'], description="原子缺陷检测模型")
    # Optional GNN checkpoint; gnn_predict_app falls back to ../egnn_v2/model/last.ckpt.
    gnn_model_path: Optional[InputFilePath] = Field(default=None, ftypes=['ckpt'], description="gnn模型")
    # Result directory; the default is overridden once deployed online.
    output_dir: OutputDirectory = Field(
        default="./result"
    )
|
||
|
||
class GnnTrainOptions(BaseModel):
    # Inputs for the GNN training sub-command.
    # Edge-length parameter forwarded to GNN training and prediction.
    edges_length: Float = Field(default=35.5, description="连接节点边长")
    # Zipped training dataset.
    train_path: InputFilePath = Field(..., ftypes=['zip'], description="训练数据集")
    # Optional zipped test dataset. (The previous description wrongly called
    # this a model.)
    test_path: Optional[InputFilePath] = Field(default=None, ftypes=['zip'], description="测试数据集")
    # Root directory for training artifacts and evaluation output.
    output_dir: OutputDirectory = Field(
        default="./result"
    )
|
||
|
||
class SetOptional(String, Enum):
    # Selectable GNN model names. Each value is a label imported from
    # model_type_dict, and combine_gnn_app uses that value as the key into
    # model_path_dict to locate the checkpoint.
    option_i_1 = norm_line_sv_label
    option_i_2 = cz_label
    option_i_3 = xj_label
    option_i_4 = smov_label
|
||
|
||
|
||
class SingleWrapper(BaseModel):
    # One entry of CombineGnnPredictOptions.select_models: which GNN model to run.
    model_name: SetOptional
    # Disabled per-model edge length (a single shared edges_length is used instead):
    #edges_length: Float = Field(description="连接节点边长")
|
||
|
||
class CombineGnnPredictOptions(BaseModel):
    # Inputs for running the bundled defect detector followed by one or more
    # named GNN models. At least one model must be selected, and duplicates
    # are rejected.
    select_models: List[SingleWrapper] = Field(
        default=[SingleWrapper(model_name=norm_line_sv_label)],
        min_items=1,
        unique_items=True,
        description="选择的模型",
    )
    # Edge-length parameter shared by every selected GNN model.
    edges_length: Float = Field(default=35.5, description="连接节点边长")
    # Zipped dataset to run prediction on.
    data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
    # NOTE: no model_path / gnn_model_path fields are exposed here. The
    # checkpoints are resolved from bundled defaults and model_path_dict.
    # Result directory; the default is overridden once deployed online.
    output_dir: OutputDirectory = Field(default="./result")
|
||
|
||
|
||
|
||
|
||
class PredictGlobalOptions(PredictOptions, BaseModel):
    # Options model registered for the "原子缺陷预测" sub-command in to_parser;
    # adds nothing beyond PredictOptions.
    ...
|
||
|
||
class TrainGlobalOptions(TrainOptions, BaseModel):
    # Options model registered for the "原子缺陷模型训练" sub-command in to_parser;
    # adds nothing beyond TrainOptions.
    ...
|
||
|
||
class GnnPredictGlobalOptions(GnnPredictOptions, BaseModel):
    # Options model registered for the "gnn预测" sub-command in to_parser;
    # adds nothing beyond GnnPredictOptions.
    ...
|
||
|
||
class GnnTrainGlobalOptions(GnnTrainOptions, BaseModel):
    # Options model registered for the "gnn模型训练" sub-command in to_parser;
    # adds nothing beyond GnnTrainOptions.
    ...
|
||
|
||
class CombineGnnPredictGlobalOptions(CombineGnnPredictOptions, BaseModel):
    # Options model registered for the "combine_gnn预测" sub-command in to_parser;
    # adds nothing beyond CombineGnnPredictOptions.
    ...
|
||
|
||
def predict_app(opts: PredictGlobalOptions) -> int:
    """Run the atomic-defect detection model on a zipped dataset.

    Falls back to the bundled checkpoint ``model/last.ckpt`` (next to this
    file) when no model is supplied.

    Args:
        opts: Parsed sub-command options. ``data_path`` is the zipped dataset
            and ``output_dir`` receives whatever ``predict_and_plot`` writes.

    Returns:
        0 once prediction finishes. Exceptions from ``predict_and_plot``
        propagate to the CLI exception handler.
    """
    if opts.model_path is None:
        print("未指定模型路径,使用默认模型")
        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model/last.ckpt")
    else:
        # Resolve the uploaded checkpoint the same way data_path/output_dir are
        # resolved, instead of passing the InputFilePath object through.
        model_path = opts.model_path.get_full_path()

    predict_and_plot(model_path, opts.data_path.get_full_path(), opts.output_dir.get_full_path())
    return 0
|
||
|
||
def gnn_predict_app(opts: GnnPredictGlobalOptions) -> int:
    """Run the defect detector, then the GNN model on its post-processed output.

    Missing checkpoints fall back to bundled defaults: ``model/last.ckpt`` for
    the detector and ``../egnn_v2/model/last.ckpt`` for the GNN, both resolved
    relative to this file.

    Args:
        opts: Parsed sub-command options (dataset zip, optional checkpoints,
            edge length, output directory).

    Returns:
        0 once prediction and report generation finish. Exceptions from the
        underlying helpers propagate.
    """
    here = os.path.dirname(os.path.abspath(__file__))

    if opts.gnn_model_path is None:
        print("未指定gnn模型路径,使用默认模型")
        gnn_model_path = os.path.join(here, "../egnn_v2/model/last.ckpt")
    else:
        # Resolve uploaded files to a path, like data_path/output_dir.
        gnn_model_path = opts.gnn_model_path.get_full_path()

    if opts.model_path is None:
        print("未指定模型路径,使用默认模型")
        model_path = os.path.join(here, "model/last.ckpt")
    else:
        model_path = opts.model_path.get_full_path()

    output_dir = opts.output_dir.get_full_path()
    save_path = predict_and_plot(model_path, opts.data_path.get_full_path(), output_dir)

    # The GNN consumes the detector's post-processed results.
    test_path = save_path["post_process"]
    print("gnn 预测开始")
    egnn_predict_and_plot(gnn_model_path, test_path, output_dir, opts.edges_length)
    gnn_generate_report(save_path, output_dir)
    print("gnn 预测结束")
    return 0
|
||
|
||
def train_app(opts: TrainGlobalOptions) -> int:
    """Train the atomic-defect detection model and optionally evaluate it.

    Steps: create the output layout with ``create_save_path``, unzip the
    training archive, run ``pre_process``, then ``train``. If a test archive
    was supplied, run ``predict_and_plot`` with the trained model.

    Args:
        opts: Parsed sub-command options (training zip, optional test zip,
            output directory).

    Returns:
        0 when training (and optional testing) finishes. Exceptions from the
        helpers propagate.
    """
    data_path = opts.train_path.get_full_path()
    base_path = opts.output_dir.get_full_path()

    # Create the output directory layout.
    save_path = create_save_path(base_path)
    print("save path: ", save_path)

    original_path = save_path['train_original']
    train_pre_process_path = save_path['train_pre_process']
    model_save_path = save_path['model_save_path']

    # Unpack the training archive into original_path.
    with zipfile.ZipFile(data_path, 'r') as zip_ref:
        zip_ref.extractall(original_path)

    pre_process(original_path, train_pre_process_path)

    print("train start")
    model_path = train(train_pre_process_path, model_save_path)

    # Only announce testing when it will actually run (a stray duplicate
    # "test start" used to be printed unconditionally here).
    if opts.test_path is None:
        print("未指定测试数据集,跳过测试")
        return 0

    print("test start")
    predict_and_plot(model_path, opts.test_path.get_full_path(), base_path)
    print("test end")
    return 0
|
||
|
||
def gnn_train_app(opts: GnnTrainGlobalOptions) -> int:
    """Train the GNN model and, if a test set is given, evaluate it end to end.

    The training zip is extracted to ``<output_dir>/train/dataset`` and the
    model is saved under ``<output_dir>/train/model``. Evaluation runs the
    bundled detector (``model/last.ckpt``) on the test zip, feeds its
    post-processed output to the newly trained GNN, and generates a report.

    Args:
        opts: Parsed sub-command options (training zip, optional test zip,
            edge length, output directory).

    Returns:
        0 when training (and optional evaluation) finishes. Exceptions from
        the helpers propagate.
    """
    train_path = opts.train_path.get_full_path()
    base_path = opts.output_dir.get_full_path()
    model_save_path = os.path.join(base_path, "train/model")
    train_data_path = os.path.join(base_path, "train/dataset")

    with zipfile.ZipFile(train_path, 'r') as zip_ref:
        zip_ref.extractall(train_data_path)

    print("gnn 训练开始")
    # egnn_train's return value is used as the GNN checkpoint for evaluation.
    model_name = egnn_train(train_data_path, model_save_path, opts.edges_length)
    print("gnn 训练结束")

    # test_path is Optional. Skip evaluation rather than crash on None.
    if opts.test_path is None:
        print("未指定测试数据集,跳过测试")
        return 0

    model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model/last.ckpt")
    save_path = predict_and_plot(model_path, opts.test_path.get_full_path(), base_path)

    test_path = save_path["post_process"]
    print("gnn 预测开始")
    egnn_predict_and_plot(model_name, test_path, base_path, opts.edges_length)
    gnn_generate_report(save_path, base_path)
    print("gnn 预测结束")
    return 0
|
||
|
||
def combine_gnn_app(opts: CombineGnnPredictGlobalOptions) -> int:
    """Run the defect detector once, then every selected GNN model on its output.

    Each selected model name is mapped through ``model_path_dict`` to a
    checkpoint under ``../egnn_v2/`` (relative to this file).

    Args:
        opts: Parsed sub-command options (selected models, edge length,
            dataset zip, output directory).

    Returns:
        0 on success. 1 if any step raises; the error and traceback are
        printed first.
    """
    try:
        print("select_models的类型是:", type(opts.select_models))
        print("选择的模型:", opts.select_models)
        for model in opts.select_models:
            print("选择的模型的名称:", model.model_name.value)

        here = os.path.dirname(os.path.abspath(__file__))

        # CombineGnnPredictOptions declares no model_path/gnn_model_path
        # fields. Reading opts.model_path directly, or assigning new
        # attributes onto the model, made this command always fail. Read the
        # optional field defensively and keep resolved paths in locals.
        user_model_path = getattr(opts, "model_path", None)
        if user_model_path is None:
            print("未指定模型路径,使用默认模型")
            model_path = os.path.join(here, "model/last.ckpt")
        else:
            model_path = user_model_path.get_full_path()

        output_dir = opts.output_dir.get_full_path()
        save_path = predict_and_plot(model_path, opts.data_path.get_full_path(), output_dir)

        # Every GNN consumes the detector's post-processed results.
        test_path = save_path["post_process"]
        print("gnn 预测开始")

        for opt in opts.select_models:
            model_name = opt.model_name.value
            relative_model_path = model_path_dict[model_name]
            gnn_model_path = os.path.join(here, "../egnn_v2/", relative_model_path)
            print("gnn model path :", gnn_model_path)
            gnn_save_path = egnn_predict_and_plot(gnn_model_path, test_path, output_dir, opts.edges_length, model_name)
            # NOTE(review): the original indentation was lost; the report is
            # generated once per selected model here. Confirm it was not
            # meant to run once after the loop.
            gnn_generate_report(gnn_save_path, output_dir)

        print("gnn 预测结束")

    except Exception as e:
        # CLI boundary: report and signal failure through the exit code.
        print(f"Error processing error: {e}")
        traceback.print_exc()
        return 1

    return 0
|
||
|
||
def to_parser():
    """Build the sub-command table: CLI name -> SubParser(options model, handler, help text).

    Entries are inserted in the order listed below.
    """
    commands = (
        ("原子缺陷预测", PredictGlobalOptions, predict_app, "预测"),
        ("原子缺陷模型训练", TrainGlobalOptions, train_app, "训练"),
        ("gnn预测", GnnPredictGlobalOptions, gnn_predict_app, "gnn预测"),
        ("combine_gnn预测", CombineGnnPredictGlobalOptions, combine_gnn_app, "gnn预测"),
        ("gnn模型训练", GnnTrainGlobalOptions, gnn_train_app, "训练"),
    )
    return {
        cli_name: SubParser(options_model, handler, help_text)
        for cli_name, options_model, handler, help_text in commands
    }
|
||
|
||
if __name__ == "__main__":
    # Script entry point: hand the sub-command table to dp.launching's
    # run_sp_and_exit, which parses the command line and dispatches.
    subcommands = to_parser()
    run_sp_and_exit(
        subcommands,
        description="原子缺陷检测工具",
        version="0.1.0",
        exception_handler=default_minimal_exception_handler,
    )
|
||
|
||
|
||
|