112 lines
3.3 KiB
Python
112 lines
3.3 KiB
Python
import argparse
|
|
import os
|
|
import zipfile
|
|
import glob
|
|
from typing import List
|
|
from pre import process_slide
|
|
from test_pl_unequal import predict
|
|
from typing import Dict
|
|
from typing import Tuple
|
|
from utils.labelme import save_pred_to_json
|
|
from plot import plot2
|
|
from predict import predict_and_plot
|
|
from train_pl import train, pre_process, create_save_path
|
|
import numpy as np
|
|
from dp.launching.report import Report, ReportSection, AutoReportElement
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
from dp.launching.typing import BaseModel, Field, Int,Optional
|
|
from dp.launching.cli import to_runner, default_minimal_exception_handler
|
|
from dp.launching.typing.io import InputFilePath, OutputDirectory
|
|
|
|
from dp.launching.cli import (
|
|
SubParser,
|
|
default_minimal_exception_handler,
|
|
run_sp_and_exit,
|
|
to_runner,
|
|
)
|
|
|
|
class PredictOptions(BaseModel):
    # Options for the "预测" (predict) sub-command.
    # (Kept as comments rather than a docstring: pydantic models expose the
    # class docstring as the schema description.)

    # Zip archive with the dataset to run prediction on ("测试的数据集" = test dataset).
    data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")

    # Optional model checkpoint ("测试使用的模型" = model used for testing).
    # An explicit None default keeps the field optional regardless of the
    # pydantic version; predict_app falls back to model/last.ckpt next to the
    # script when it is None.
    model_path: Optional[InputFilePath] = Field(
        default=None, ftypes=['ckpt'], description="测试使用的模型"
    )

    # Directory that receives the prediction output.
    output_dir: OutputDirectory = Field(
        default="./result"
    )  # the default is overridden when the app runs online
|
|
|
|
class TrainOptions(BaseModel):
    # Options for the "训练" (train) sub-command.
    # (Kept as comments rather than a docstring: pydantic models expose the
    # class docstring as the schema description.)

    # Zip archive with the training dataset ("训练数据集" = training dataset).
    train_path: InputFilePath = Field(..., ftypes=['zip'], description="训练数据集")

    # Optional zip archive with a test dataset ("测试数据集" = test dataset).
    # train_app hands it to predict_and_plot in the same argument position
    # that predict_app uses for its zip data_path, so it must accept 'zip'
    # files rather than 'ckpt' checkpoints. None means "skip testing".
    test_path: Optional[InputFilePath] = Field(
        default=None, ftypes=['zip'], description="测试数据集"
    )

    # Base directory passed to create_save_path and used as the output
    # directory for test results.
    output_dir: OutputDirectory = Field(
        default="./result"
    )
|
|
|
|
|
|
|
|
class PredictGlobalOptions(PredictOptions, BaseModel):
    # Top-level option model registered for the "预测" (predict) sub-command.
    # It currently adds nothing beyond the fields inherited from PredictOptions.
    pass
|
|
|
|
class TrainGlobalOptions(TrainOptions, BaseModel):
    # Top-level option model registered for the "训练" (train) sub-command.
    # It currently adds nothing beyond the fields inherited from TrainOptions.
    pass
|
|
|
|
def predict_app(opts: PredictGlobalOptions) -> int:
    """Run prediction on the zipped dataset in ``opts`` and plot the results.

    Args:
        opts: Parsed options. If ``opts.model_path`` is None, the checkpoint
            ``model/last.ckpt`` located next to this script is used instead.

    Returns:
        0 once ``predict_and_plot`` has completed.
    """
    if opts.model_path is None:
        # "No model path specified, using the default model"
        print("未指定模型路径,使用默认模型")
        model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model/last.ckpt")
        # The fallback is deliberately not written back into opts: the field
        # is typed InputFilePath and would end up holding a plain str.
    else:
        # Resolve to a concrete path the same way data_path and output_dir
        # are resolved below, instead of passing the InputFilePath object.
        model_path = opts.model_path.get_full_path()

    predict_and_plot(model_path, opts.data_path.get_full_path(), opts.output_dir.get_full_path())
    return 0
|
|
|
|
def train_app(opts: TrainGlobalOptions) -> int:
    """Train a model on the zipped dataset in ``opts`` and optionally test it.

    Steps: build the output directory layout via ``create_save_path``, unzip
    the training archive, preprocess it, train, and if ``opts.test_path`` is
    set, run ``predict_and_plot`` with the trained model on that dataset.

    Args:
        opts: Parsed training options.

    Returns:
        0 on completion (consistent with ``predict_app``).
    """
    # Create the output directory layout.
    data_path = opts.train_path.get_full_path()
    base_path = opts.output_dir.get_full_path()
    save_path = create_save_path(base_path)
    print("save path: ", save_path)

    original_path = save_path['train_original']
    train_pre_process_path = save_path['train_pre_process']
    model_save_path = save_path['model_save_path']

    # Unpack the training dataset into original_path.
    with zipfile.ZipFile(data_path, 'r') as zip_ref:
        zip_ref.extractall(original_path)

    pre_process(original_path, train_pre_process_path)
    print("train start")
    model_path = train(train_pre_process_path, model_save_path)

    # "test start" is only logged when a test dataset was actually supplied.
    if opts.test_path is not None:
        print("test start")
        test_path = opts.test_path.get_full_path()
        predict_and_plot(model_path, test_path, base_path)
        print("test end")
    else:
        # "No test dataset specified, skipping test"
        print("未指定测试数据集,跳过测试")

    # Honour the declared ``-> int`` return type.
    return 0
|
|
|
|
|
|
|
|
|
|
def to_parser():
    """Build the sub-command table consumed by ``run_sp_and_exit``.

    Each entry maps a sub-command name ("预测" = predict, "训练" = train) to a
    ``SubParser`` built from its options model, its handler, and the same name
    reused as the help text.
    """
    commands = (
        ("预测", PredictGlobalOptions, predict_app),
        ("训练", TrainGlobalOptions, train_app),
    )
    return {label: SubParser(model, handler, label) for label, model, handler in commands}
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand the sub-command table from to_parser() to
    # dp.launching's run_sp_and_exit, which presumably parses the command
    # line, runs the selected handler and exits (behaviour defined in
    # dp.launching.cli, not visible here).
    run_sp_and_exit(
        to_parser(),
        description="原子缺陷检测工具",  # "Atomic defect detection tool"
        version="0.1.0",
        exception_handler=default_minimal_exception_handler,
    )
|
|
|
|
|