gnn模型预测和训练新增egdes_length参数
This commit is contained in:
parent
4b931ed2ff
commit
afc9ca3e3d
|
@ -11,6 +11,11 @@ from torch_geometric.data import InMemoryDataset, download_url, Dataset
|
|||
from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay
|
||||
import albumentations as A
|
||||
|
||||
try:
|
||||
A.check_version()
|
||||
except Exception as e:
|
||||
print(f"Version check failed: {e}")
|
||||
|
||||
def get_training_augmentation():
|
||||
train_transform = [
|
||||
A.GaussianBlur(),
|
||||
|
@ -203,7 +208,7 @@ def load_data_v2(json_path):
|
|||
return points, edge_index, labels, lights
|
||||
|
||||
|
||||
def load_data_v2_aug(json_path,aug):
|
||||
def load_data_v2_aug(json_path, aug, edges_length=35.5):
|
||||
# data_dict = {
|
||||
# 'V': 0,
|
||||
# 'S': 1,
|
||||
|
@ -247,7 +252,7 @@ def load_data_v2_aug(json_path,aug):
|
|||
labels = np.array([0 for item in json_data['shapes']], np.uint8)
|
||||
|
||||
lights = np.array([get_light(img, point) for point in points])
|
||||
edge_index = get_edge_index_delaunay(points, max_edge_length=35.5)
|
||||
edge_index = get_edge_index_delaunay(points, edges_length)
|
||||
|
||||
#line 35
|
||||
|
||||
|
@ -522,10 +527,12 @@ class AtomDataset_v2(InMemoryDataset):
|
|||
|
||||
|
||||
class AtomDataset_v2_aug(InMemoryDataset):
|
||||
def __init__(self, root, transform=None, pre_transform=None):
|
||||
def __init__(self, root, edges_length=35.5, transform=None, pre_transform=None):
|
||||
self.edges_length = edges_length
|
||||
super(AtomDataset_v2_aug, self).__init__(root, transform, pre_transform)
|
||||
self.data, self.slices = torch.load(self.processed_paths[0])
|
||||
|
||||
|
||||
@property
|
||||
def raw_file_names(self):
|
||||
return []
|
||||
|
@ -540,13 +547,13 @@ class AtomDataset_v2_aug(InMemoryDataset):
|
|||
def process(self):
|
||||
json_lst = glob.glob('{}/*.json'.format(self.root), recursive=True)
|
||||
print('Json data number: {}.'.format(len(json_lst)))
|
||||
|
||||
print("edges:", self.edges_length)
|
||||
data_lst = []
|
||||
# for i in range(10):
|
||||
for json_path in json_lst:
|
||||
base_name = json_path.split('/')[-1].split('.')[0]
|
||||
aug = get_training_augmentation() if json_path.split('/')[-2] == 'train' else get_validation_augmentation()
|
||||
points, edge_index, labels, lights = load_data_v2_aug(json_path,aug)
|
||||
points, edge_index, labels, lights = load_data_v2_aug(json_path,aug, self.edges_length)
|
||||
|
||||
for idx, center in enumerate(points):
|
||||
ner1 = find_ner(idx, edge_index)
|
||||
|
|
|
@ -29,6 +29,7 @@ from egnn_utils.save import save_results
|
|||
from resize import resize_images
|
||||
from metricse2e_vor import post_process
|
||||
from plot_view import plot2
|
||||
from dp.launching.report import Report, ReportSection, AutoReportElement
|
||||
# constants
|
||||
|
||||
DATASETS = '0'
|
||||
|
@ -139,7 +140,7 @@ def save_predict_result(predictions, save_path):
|
|||
|
||||
|
||||
|
||||
def predict(model_name, test_path, save_path):
|
||||
def predict(model_name, test_path, save_path, edges_length=35.5):
|
||||
os.makedirs(save_path, exist_ok=True)
|
||||
# 获取本文件的绝对路径
|
||||
file_path = os.path.abspath(__file__)
|
||||
|
@ -151,9 +152,9 @@ def predict(model_name, test_path, save_path):
|
|||
print("train path: ", train_path)
|
||||
print("valid_path: ", valid_path)
|
||||
|
||||
train_dataset = AtomDataset_v2_aug(root= os.path.join(dir_path, 'gnn_sv_train/train/'))
|
||||
valid_dataset = AtomDataset_v2_aug(root=os.path.join(dir_path,'gnn_sv_train/valid/'))
|
||||
test_dataset = AtomDataset_v2_aug(root=test_path)
|
||||
train_dataset = AtomDataset_v2_aug(root= os.path.join(dir_path, 'gnn_sv_train/train/'), edges_length=edges_length)
|
||||
valid_dataset = AtomDataset_v2_aug(root=os.path.join(dir_path,'gnn_sv_train/valid/'), edges_length=edges_length)
|
||||
test_dataset = AtomDataset_v2_aug(root=test_path, edges_length=edges_length)
|
||||
#e2e_dataset = AtomDataset_v2_aug(root=test_path)
|
||||
|
||||
datamodule = LightningDataset(
|
||||
|
@ -231,16 +232,44 @@ def create_save_path(base_path: str) -> Dict[str, str]:
|
|||
# zip_ref.extractall(original_path)
|
||||
#
|
||||
|
||||
def predict_and_plot(model_name, test_path, save_base_path):
|
||||
|
||||
def gnn_generate_report(save_path: Dict[str, str], output_dir: str) -> None:
|
||||
img_elements = []
|
||||
#原始图片在遍历,路径在save_path["dataset"]
|
||||
for img_path in glob.glob(save_path["dataset"] + "/**/*.jpg", recursive=True):
|
||||
img_elements.append(AutoReportElement(
|
||||
path=os.path.relpath(img_path, output_dir),
|
||||
title=img_path.split("/")[-1],
|
||||
description=f'原始图片',
|
||||
))
|
||||
|
||||
ori_img_section = ReportSection(title="原始图片", ncols=2, elements=img_elements)
|
||||
|
||||
img_elements = []
|
||||
#预测结果在遍历,
|
||||
for img_path in glob.glob(save_path["plot_view"] + "/*.jpg"):
|
||||
img_elements.append(AutoReportElement(
|
||||
path=os.path.relpath(img_path, output_dir),
|
||||
title=img_path.split('/')[-1],
|
||||
description=f'预测结果',
|
||||
) )
|
||||
|
||||
post_process_img_section = ReportSection(title="预测结果", ncols=2, elements=img_elements)
|
||||
|
||||
report = Report(title="预测结果", sections=[post_process_img_section, ori_img_section])
|
||||
report.save(output_dir)
|
||||
|
||||
def predict_and_plot(model_name, test_path, save_base_path, edges_length=35.5):
|
||||
print("gnn model: ", model_name)
|
||||
#创建保存路径
|
||||
save_path = create_save_path(save_base_path)
|
||||
# 数据预处理
|
||||
resize_images(512,256,test_path)
|
||||
predict_result =predict(model_name, test_path, save_path['gnn_predict_result'])
|
||||
predict_result =predict(model_name, test_path, save_path['gnn_predict_result'], edges_length=edges_length)
|
||||
post_process(predict_result, test_path, save_path['gnn_predict_post_process'])
|
||||
plot2(img_folder=save_path['gnn_predict_post_process'], json_folder= save_path['gnn_predict_post_process'],
|
||||
output_folder=save_path['gnn_predict_result_view'])
|
||||
gnn_generate_report(save_path, save_base_path)
|
||||
|
||||
#
|
||||
# if __name__ == '__main__':
|
||||
|
|
|
@ -117,12 +117,12 @@ class PLModel(pl.LightningModule):
|
|||
# main
|
||||
|
||||
|
||||
def train(dataset_path, model_save_path):
|
||||
def train(dataset_path, model_save_path, edges_length=35.5):
|
||||
os.makedirs(model_save_path, exist_ok=True)
|
||||
train_dataset = AtomDataset_v2_aug(root='{}/train/'.format(dataset_path))
|
||||
valid_dataset = AtomDataset_v2_aug(root='{}/valid/'.format(dataset_path))
|
||||
test_dataset = AtomDataset_v2_aug(root='{}/test/'.format(dataset_path))
|
||||
e2e_dataset = AtomDataset_v2_aug(root='{}/test/'.format(dataset_path))
|
||||
train_dataset = AtomDataset_v2_aug(root='{}/train/'.format(dataset_path), edges_length=edges_length)
|
||||
valid_dataset = AtomDataset_v2_aug(root='{}/valid/'.format(dataset_path), edges_length=edges_length)
|
||||
#test_dataset = AtomDataset_v2_aug(root='{}/test/'.format(dataset_path))
|
||||
#e2e_dataset = AtomDataset_v2_aug(root='{}/test/'.format(dataset_path))
|
||||
|
||||
datamodule = LightningDataset(
|
||||
train_dataset,
|
||||
|
@ -182,6 +182,7 @@ def train(dataset_path, model_save_path):
|
|||
else:
|
||||
print("No checkpoint was saved by the checkpoint callback.")
|
||||
|
||||
return model_name
|
||||
#trainer.save_checkpoint(model_name)
|
||||
# # inference test
|
||||
# predictions = trainer.predict(
|
||||
|
@ -191,10 +192,10 @@ def train(dataset_path, model_save_path):
|
|||
# )
|
||||
# save_results(trainer.log_dir, predictions, 'test')
|
||||
|
||||
if __name__ == '__main__':
|
||||
dataset_path = './gnn_sv_train/'
|
||||
model_save_path = './train_result/train_model/'
|
||||
train(dataset_path, model_save_path)
|
||||
# if __name__ == '__main__':
|
||||
# dataset_path = './gnn_sv_train/'
|
||||
# model_save_path = './train_result/train_model/'
|
||||
# train(dataset_path, model_save_path)
|
||||
|
||||
# if __name__ == '__main__':
|
||||
# train_dataset = AtomDataset_v2_aug(root='{}/train/'.format(DATA_PATH))
|
||||
|
|
|
@ -16,7 +16,7 @@ from dp.launching.report import Report, ReportSection, AutoReportElement
|
|||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from dp.launching.typing import BaseModel, Field, Int,Optional
|
||||
from dp.launching.typing import BaseModel, Field, Float, Optional
|
||||
from dp.launching.cli import to_runner, default_minimal_exception_handler
|
||||
from dp.launching.typing.io import InputFilePath, OutputDirectory
|
||||
|
||||
|
@ -29,6 +29,7 @@ from dp.launching.cli import (
|
|||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn_v2')))
|
||||
print(sys.path)
|
||||
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot
|
||||
from train_pl_v2_aug import train as egnn_train
|
||||
|
||||
class PredictOptions(BaseModel):
|
||||
data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
|
||||
|
@ -45,6 +46,7 @@ class TrainOptions(BaseModel):
|
|||
)
|
||||
|
||||
class GnnPredictOptions(BaseModel):
|
||||
edges_length: Float = Field(default=35.5, description="连接节点边长")
|
||||
data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
|
||||
model_path: Optional[InputFilePath] = Field(ftypes=['ckpt'], description="原子缺陷检测模型")
|
||||
gnn_model_path: Optional[InputFilePath] = Field(ftypes=['ckpt'], description="gnn模型")
|
||||
|
@ -52,6 +54,13 @@ class GnnPredictOptions(BaseModel):
|
|||
default="./result"
|
||||
) # default will be override after online
|
||||
|
||||
class GnnTrainOptions(BaseModel):
|
||||
edges_length: Float = Field(default=35.5, description="连接节点边长")
|
||||
train_path: InputFilePath = Field(..., ftypes=['zip'], description="训练数据集")
|
||||
test_path: Optional[InputFilePath] = Field(ftypes=['zip'], description="测试使用的模型")
|
||||
output_dir: OutputDirectory = Field(
|
||||
default="./result"
|
||||
)
|
||||
|
||||
class PredictGlobalOptions(PredictOptions, BaseModel):
|
||||
...
|
||||
|
@ -62,6 +71,9 @@ class TrainGlobalOptions(TrainOptions, BaseModel):
|
|||
class GnnPredictGlobalOptions(GnnPredictOptions, BaseModel):
|
||||
...
|
||||
|
||||
class GnnTrainGlobalOptions(GnnTrainOptions, BaseModel):
|
||||
...
|
||||
|
||||
def predict_app(opts: PredictGlobalOptions) -> int:
|
||||
if opts.model_path is None:
|
||||
print("未指定模型路径,使用默认模型")
|
||||
|
@ -92,7 +104,7 @@ def gnn_predict_app(opts: GnnPredictGlobalOptions) -> int:
|
|||
|
||||
test_path = save_path["post_process"]
|
||||
print("gnn 预测开始")
|
||||
egnn_predict_and_plot(gnn_model_path, test_path, opts.output_dir.get_full_path())
|
||||
egnn_predict_and_plot(gnn_model_path, test_path, opts.output_dir.get_full_path(), opts.edges_length)
|
||||
print("gnn 预测结束")
|
||||
|
||||
def train_app(opts: TrainGlobalOptions) -> int:
|
||||
|
@ -126,14 +138,34 @@ def train_app(opts: TrainGlobalOptions) -> int:
|
|||
else:
|
||||
print("未指定测试数据集,跳过测试")
|
||||
|
||||
def gnn_train_app(opts: GnnTrainGlobalOptions) -> int:
|
||||
train_path = opts.train_path.get_full_path()
|
||||
base_path = opts.output_dir.get_full_path()
|
||||
model_save_path = os.path.join(base_path, "train/model")
|
||||
train_data_path = os.path.join(base_path, "train/dataset")
|
||||
with zipfile.ZipFile(train_path, 'r') as zip_ref:
|
||||
zip_ref.extractall(train_data_path)
|
||||
|
||||
print("gnn 训练开始")
|
||||
model_name = egnn_train(train_data_path, model_save_path, opts.edges_length)
|
||||
print("gnn 训练结束")
|
||||
|
||||
|
||||
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model/last.ckpt")
|
||||
save_path = predict_and_plot(model_path, opts.test_path.get_full_path(), opts.output_dir.get_full_path())
|
||||
|
||||
test_path = save_path["post_process"]
|
||||
print("gnn 预测开始")
|
||||
egnn_predict_and_plot(model_name, test_path, opts.output_dir.get_full_path(), opts.edges_length)
|
||||
print("gnn 预测结束")
|
||||
|
||||
|
||||
def to_parser():
|
||||
return {
|
||||
"原子缺陷预测": SubParser(PredictGlobalOptions, predict_app, "预测"),
|
||||
"训练": SubParser(TrainGlobalOptions, train_app, "训练"),
|
||||
"gnn预测" : SubParser(GnnPredictGlobalOptions, gnn_predict_app, "gnn预测")
|
||||
"原子缺陷模型训练": SubParser(TrainGlobalOptions, train_app, "训练"),
|
||||
"gnn预测" : SubParser(GnnPredictGlobalOptions, gnn_predict_app, "gnn预测"),
|
||||
"gnn模型训练": SubParser(GnnTrainGlobalOptions, gnn_train_app, "训练"),
|
||||
}
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Binary file not shown.
Loading…
Reference in New Issue