atom-predict/msunet/app.py

112 lines
3.3 KiB
Python

import argparse
import os
import zipfile
import glob
from typing import List
from pre import process_slide
from test_pl_unequal import predict
from typing import Dict
from typing import Tuple
from utils.labelme import save_pred_to_json
from plot import plot2
from predict import predict_and_plot
from train_pl import train, pre_process, create_save_path
import numpy as np
from dp.launching.report import Report, ReportSection, AutoReportElement
import sys
from pathlib import Path
from dp.launching.typing import BaseModel, Field, Int,Optional
from dp.launching.cli import to_runner, default_minimal_exception_handler
from dp.launching.typing.io import InputFilePath, OutputDirectory
from dp.launching.cli import (
SubParser,
default_minimal_exception_handler,
run_sp_and_exit,
to_runner,
)
class PredictOptions(BaseModel):
    # Options for the prediction sub-command.

    # Zip archive holding the dataset to run prediction on (required).
    data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
    # Optional checkpoint; predict_app falls back to its bundled model when
    # this is None, so the None default is spelled out explicitly instead of
    # relying on the implicit default of an Optional annotation.
    model_path: Optional[InputFilePath] = Field(
        default=None, ftypes=['ckpt'], description="测试使用的模型"
    )
    # Directory that receives the results.
    output_dir: OutputDirectory = Field(
        default="./result"
    )  # this default is overridden once the app runs online
class TrainOptions(BaseModel):
    # Options for the training sub-command.

    # Zip archive holding the training dataset (required).
    train_path: InputFilePath = Field(..., ftypes=['zip'], description="训练数据集")
    # Optional test dataset. train_app hands this to predict_and_plot as the
    # data argument (the same slot predict_app fills with its zip dataset),
    # so it must accept a zip archive rather than a checkpoint file.
    test_path: Optional[InputFilePath] = Field(
        default=None, ftypes=['zip'], description="测试数据集"
    )
    # Directory that receives the training artifacts and any test results.
    output_dir: OutputDirectory = Field(
        default="./result"
    )
class PredictGlobalOptions(PredictOptions, BaseModel):
    # Option model registered for the prediction sub-command; it adds no
    # fields of its own beyond those inherited from PredictOptions.
    pass
class TrainGlobalOptions(TrainOptions, BaseModel):
    # Option model registered for the training sub-command; it adds no
    # fields of its own beyond those inherited from TrainOptions.
    pass
def predict_app(opts: PredictGlobalOptions) -> int:
    """Run prediction on the supplied dataset and plot the results.

    The checkpoint given in ``opts.model_path`` is used when present;
    otherwise ``model/last.ckpt`` located next to this file is used.

    Args:
        opts: Parsed options providing ``data_path``, the optional
            ``model_path`` and ``output_dir``.

    Returns:
        Always 0.
    """
    if opts.model_path is None:
        print("未指定模型路径,使用默认模型")
        model_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "model/last.ckpt"
        )
    else:
        # Resolve the uploaded file to a plain path, exactly as is done for
        # data_path and output_dir below, instead of forwarding the option
        # object itself.
        model_path = opts.model_path.get_full_path()
    predict_and_plot(
        model_path,
        opts.data_path.get_full_path(),
        opts.output_dir.get_full_path(),
    )
    return 0
def train_app(opts: TrainGlobalOptions) -> int:
    """Train a model from a zipped dataset and optionally test it.

    Steps: create the output layout, extract the training archive,
    pre-process it, train, and, when ``opts.test_path`` is given, run
    ``predict_and_plot`` with the value returned by ``train``.

    Args:
        opts: Parsed options providing ``train_path``, the optional
            ``test_path`` and ``output_dir``.

    Returns:
        Always 0, matching ``predict_app``.
    """
    data_path = opts.train_path.get_full_path()
    base_path = opts.output_dir.get_full_path()

    # Create the working directories under the output directory.
    save_path = create_save_path(base_path)
    print("save path: ", save_path)
    original_path = save_path['train_original']
    train_pre_process_path = save_path['train_pre_process']
    model_save_path = save_path['model_save_path']

    # Extract the dataset archive into original_path.
    with zipfile.ZipFile(data_path, 'r') as zip_ref:
        zip_ref.extractall(original_path)

    pre_process(original_path, train_pre_process_path)

    print("train start")
    model_path = train(train_pre_process_path, model_save_path)

    # "test start" is announced only when a test actually runs.
    if opts.test_path is not None:
        print("test start")
        test_path = opts.test_path.get_full_path()
        predict_and_plot(model_path, test_path, base_path)
        print("test end")
    else:
        print("未指定测试数据集,跳过测试")
    return 0
def to_parser():
    """Build the mapping from sub-command name to its ``SubParser``.

    Returns:
        A dict with one entry per sub-command, prediction first and
        training second.
    """
    parsers = {}
    parsers["预测"] = SubParser(PredictGlobalOptions, predict_app, "预测")
    parsers["训练"] = SubParser(TrainGlobalOptions, train_app, "训练")
    return parsers
if __name__ == "__main__":
    # Script entry point: build the sub-command table, then hand it to the
    # dp.launching runner together with the tool's description and version.
    subcommands = to_parser()
    run_sp_and_exit(
        subcommands,
        description="原子缺陷检测工具",
        version="0.1.0",
        exception_handler=default_minimal_exception_handler,
    )