gnn模型预测和训练新增edges_length参数

This commit is contained in:
somunslotus 2024-07-23 15:59:58 +08:00
parent 4b931ed2ff
commit afc9ca3e3d
5 changed files with 93 additions and 24 deletions

View File

@ -11,6 +11,11 @@ from torch_geometric.data import InMemoryDataset, download_url, Dataset
from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay
import albumentations as A
# Best-effort albumentations version check: any Exception raised by the call
# is printed and otherwise ignored, so execution continues past this point.
# NOTE(review): confirm the installed albumentations release actually exposes
# `A.check_version()`; if it does not, the except branch is always taken.
try:
A.check_version()
except Exception as e:
print(f"Version check failed: {e}")
def get_training_augmentation():
train_transform = [
A.GaussianBlur(),
@ -203,7 +208,7 @@ def load_data_v2(json_path):
return points, edge_index, labels, lights
def load_data_v2_aug(json_path,aug):
def load_data_v2_aug(json_path, aug, edges_length=35.5):
# data_dict = {
# 'V': 0,
# 'S': 1,
@ -247,7 +252,7 @@ def load_data_v2_aug(json_path,aug):
labels = np.array([0 for item in json_data['shapes']], np.uint8)
lights = np.array([get_light(img, point) for point in points])
edge_index = get_edge_index_delaunay(points, max_edge_length=35.5)
edge_index = get_edge_index_delaunay(points, edges_length)
#line 35
@ -522,10 +527,12 @@ class AtomDataset_v2(InMemoryDataset):
class AtomDataset_v2_aug(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
def __init__(self, root, edges_length=35.5, transform=None, pre_transform=None):
self.edges_length = edges_length
super(AtomDataset_v2_aug, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return []
@ -540,13 +547,13 @@ class AtomDataset_v2_aug(InMemoryDataset):
def process(self):
json_lst = glob.glob('{}/*.json'.format(self.root), recursive=True)
print('Json data number: {}.'.format(len(json_lst)))
print("edges:", self.edges_length)
data_lst = []
# for i in range(10):
for json_path in json_lst:
base_name = json_path.split('/')[-1].split('.')[0]
aug = get_training_augmentation() if json_path.split('/')[-2] == 'train' else get_validation_augmentation()
points, edge_index, labels, lights = load_data_v2_aug(json_path,aug)
points, edge_index, labels, lights = load_data_v2_aug(json_path,aug, self.edges_length)
for idx, center in enumerate(points):
ner1 = find_ner(idx, edge_index)

View File

@ -29,6 +29,7 @@ from egnn_utils.save import save_results
from resize import resize_images
from metricse2e_vor import post_process
from plot_view import plot2
from dp.launching.report import Report, ReportSection, AutoReportElement
# constants
DATASETS = '0'
@ -139,7 +140,7 @@ def save_predict_result(predictions, save_path):
def predict(model_name, test_path, save_path):
def predict(model_name, test_path, save_path, edges_length=35.5):
os.makedirs(save_path, exist_ok=True)
# 获取本文件的绝对路径
file_path = os.path.abspath(__file__)
@ -151,9 +152,9 @@ def predict(model_name, test_path, save_path):
print("train path: ", train_path)
print("valid_path: ", valid_path)
train_dataset = AtomDataset_v2_aug(root= os.path.join(dir_path, 'gnn_sv_train/train/'))
valid_dataset = AtomDataset_v2_aug(root=os.path.join(dir_path,'gnn_sv_train/valid/'))
test_dataset = AtomDataset_v2_aug(root=test_path)
train_dataset = AtomDataset_v2_aug(root= os.path.join(dir_path, 'gnn_sv_train/train/'), edges_length=edges_length)
valid_dataset = AtomDataset_v2_aug(root=os.path.join(dir_path,'gnn_sv_train/valid/'), edges_length=edges_length)
test_dataset = AtomDataset_v2_aug(root=test_path, edges_length=edges_length)
#e2e_dataset = AtomDataset_v2_aug(root=test_path)
datamodule = LightningDataset(
@ -231,16 +232,44 @@ def create_save_path(base_path: str) -> Dict[str, str]:
# zip_ref.extractall(original_path)
#
def predict_and_plot(model_name, test_path, save_base_path):
def _image_section(title, pattern, output_dir, recursive=False):
    """Build a two-column report section from the images matching *pattern*.

    Each element's path is made relative to *output_dir*; its title is the
    file name and its description repeats the section *title*.
    """
    elements = [
        AutoReportElement(
            path=os.path.relpath(img_path, output_dir),
            # basename() is separator-agnostic, unlike split("/")[-1].
            title=os.path.basename(img_path),
            description=title,
        )
        # glob yields files in arbitrary order; sort for a reproducible report.
        for img_path in sorted(glob.glob(pattern, recursive=recursive))
    ]
    return ReportSection(title=title, ncols=2, elements=elements)


def gnn_generate_report(save_path: Dict[str, str], output_dir: str) -> None:
    """Write a report with the prediction images and the original images.

    Args:
        save_path: Mapping of output locations; the keys ``"dataset"`` and
            ``"plot_view"`` are read here.
        output_dir: Directory the report is saved to; image paths in the
            report are expressed relative to it.

    Raises:
        KeyError: If ``save_path`` lacks ``"dataset"`` or ``"plot_view"``.
    """
    # Original images: searched recursively below save_path["dataset"].
    ori_img_section = _image_section(
        "原始图片", save_path["dataset"] + "/**/*.jpg", output_dir, recursive=True
    )
    # Prediction results: only the top level of save_path["plot_view"].
    # NOTE(review): predict_and_plot writes its plots to
    # save_path['gnn_predict_result_view']; confirm "plot_view" is the
    # intended key for the GNN results.
    post_process_img_section = _image_section(
        "预测结果", save_path["plot_view"] + "/*.jpg", output_dir
    )

    # The prediction section is listed before the original images.
    report = Report(title="预测结果", sections=[post_process_img_section, ori_img_section])
    report.save(output_dir)
def predict_and_plot(model_name, test_path, save_base_path, edges_length=35.5):
    """Run GNN prediction on *test_path*, post-process, plot and report.

    Args:
        model_name: Model identifier/path forwarded to ``predict``.
        test_path: Directory holding the data to predict on.
        save_base_path: Root directory for all outputs and for the report.
        edges_length: Edge length forwarded to ``predict`` (default 35.5).
    """
    print("gnn model: ", model_name)
    # Create the output directory layout.
    save_path = create_save_path(save_base_path)

    # Data preprocessing.
    resize_images(512, 256, test_path)

    # Predict exactly once, honouring the caller's edges_length. The earlier
    # extra call without edges_length only produced a discarded result.
    predict_result = predict(
        model_name,
        test_path,
        save_path['gnn_predict_result'],
        edges_length=edges_length,
    )
    post_process(predict_result, test_path, save_path['gnn_predict_post_process'])
    plot2(
        img_folder=save_path['gnn_predict_post_process'],
        json_folder=save_path['gnn_predict_post_process'],
        output_folder=save_path['gnn_predict_result_view'],
    )
    gnn_generate_report(save_path, save_base_path)
#
# if __name__ == '__main__':

View File

@ -117,12 +117,12 @@ class PLModel(pl.LightningModule):
# main
def train(dataset_path, model_save_path):
def train(dataset_path, model_save_path, edges_length=35.5):
os.makedirs(model_save_path, exist_ok=True)
train_dataset = AtomDataset_v2_aug(root='{}/train/'.format(dataset_path))
valid_dataset = AtomDataset_v2_aug(root='{}/valid/'.format(dataset_path))
test_dataset = AtomDataset_v2_aug(root='{}/test/'.format(dataset_path))
e2e_dataset = AtomDataset_v2_aug(root='{}/test/'.format(dataset_path))
train_dataset = AtomDataset_v2_aug(root='{}/train/'.format(dataset_path), edges_length=edges_length)
valid_dataset = AtomDataset_v2_aug(root='{}/valid/'.format(dataset_path), edges_length=edges_length)
#test_dataset = AtomDataset_v2_aug(root='{}/test/'.format(dataset_path))
#e2e_dataset = AtomDataset_v2_aug(root='{}/test/'.format(dataset_path))
datamodule = LightningDataset(
train_dataset,
@ -182,6 +182,7 @@ def train(dataset_path, model_save_path):
else:
print("No checkpoint was saved by the checkpoint callback.")
return model_name
#trainer.save_checkpoint(model_name)
# # inference test
# predictions = trainer.predict(
@ -191,10 +192,10 @@ def train(dataset_path, model_save_path):
# )
# save_results(trainer.log_dir, predictions, 'test')
if __name__ == '__main__':
dataset_path = './gnn_sv_train/'
model_save_path = './train_result/train_model/'
train(dataset_path, model_save_path)
# if __name__ == '__main__':
# dataset_path = './gnn_sv_train/'
# model_save_path = './train_result/train_model/'
# train(dataset_path, model_save_path)
# if __name__ == '__main__':
# train_dataset = AtomDataset_v2_aug(root='{}/train/'.format(DATA_PATH))

View File

@ -16,7 +16,7 @@ from dp.launching.report import Report, ReportSection, AutoReportElement
import sys
from pathlib import Path
from dp.launching.typing import BaseModel, Field, Int,Optional
from dp.launching.typing import BaseModel, Field, Float, Optional
from dp.launching.cli import to_runner, default_minimal_exception_handler
from dp.launching.typing.io import InputFilePath, OutputDirectory
@ -29,6 +29,7 @@ from dp.launching.cli import (
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../egnn_v2')))
print(sys.path)
from predict_pl_v2_aug import predict_and_plot as egnn_predict_and_plot
from train_pl_v2_aug import train as egnn_train
class PredictOptions(BaseModel):
data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
@ -45,6 +46,7 @@ class TrainOptions(BaseModel):
)
class GnnPredictOptions(BaseModel):
edges_length: Float = Field(default=35.5, description="连接节点边长")
data_path: InputFilePath = Field(..., ftypes=['zip'], description="测试的数据集")
model_path: Optional[InputFilePath] = Field(ftypes=['ckpt'], description="原子缺陷检测模型")
gnn_model_path: Optional[InputFilePath] = Field(ftypes=['ckpt'], description="gnn模型")
@ -52,6 +54,13 @@ class GnnPredictOptions(BaseModel):
default="./result"
) # default will be override after online
class GnnTrainOptions(BaseModel):
    # Edge length passed through to GNN training and prediction.
    edges_length: Float = Field(default=35.5, description="连接节点边长")
    # Required zip archive containing the training data.
    train_path: InputFilePath = Field(..., ftypes=['zip'], description="训练数据集")
    # Optional zip archive with test data. The description previously said
    # "model used for testing", which did not match a zip dataset input.
    test_path: Optional[InputFilePath] = Field(ftypes=['zip'], description="测试数据集")
    # Directory that receives the training outputs.
    output_dir: OutputDirectory = Field(
        default="./result"
    )
class PredictGlobalOptions(PredictOptions, BaseModel):
...
@ -62,6 +71,9 @@ class TrainGlobalOptions(TrainOptions, BaseModel):
# Options type accepted by gnn_predict_app; adds no fields of its own beyond
# those inherited from GnnPredictOptions.
class GnnPredictGlobalOptions(GnnPredictOptions, BaseModel):
...
# Options type accepted by gnn_train_app; adds no fields of its own beyond
# those inherited from GnnTrainOptions.
class GnnTrainGlobalOptions(GnnTrainOptions, BaseModel):
...
def predict_app(opts: PredictGlobalOptions) -> int:
if opts.model_path is None:
print("未指定模型路径,使用默认模型")
@ -92,7 +104,7 @@ def gnn_predict_app(opts: GnnPredictGlobalOptions) -> int:
test_path = save_path["post_process"]
print("gnn 预测开始")
egnn_predict_and_plot(gnn_model_path, test_path, opts.output_dir.get_full_path())
egnn_predict_and_plot(gnn_model_path, test_path, opts.output_dir.get_full_path(), opts.edges_length)
print("gnn 预测结束")
def train_app(opts: TrainGlobalOptions) -> int:
@ -126,14 +138,34 @@ def train_app(opts: TrainGlobalOptions) -> int:
else:
print("未指定测试数据集,跳过测试")
def gnn_train_app(opts: GnnTrainGlobalOptions) -> int:
    """Train the GNN model and, if a test set is given, run prediction on it.

    The training zip is extracted under ``<output_dir>/train/dataset`` and the
    model is saved under ``<output_dir>/train/model``. When ``opts.test_path``
    is set, the bundled ``model/last.ckpt`` detector is run on it first and
    the freshly trained GNN model is then applied to its post-processed output.

    Args:
        opts: Parsed options (train/test archives, output dir, edges_length).

    Returns:
        0 on completion.
    """
    train_path = opts.train_path.get_full_path()
    base_path = opts.output_dir.get_full_path()
    model_save_path = os.path.join(base_path, "train/model")
    train_data_path = os.path.join(base_path, "train/dataset")

    # NOTE(review): egnn_train reads "<dataset_path>/train/" and
    # "<dataset_path>/valid/"; confirm the archive has those folders at its
    # top level.
    with zipfile.ZipFile(train_path, 'r') as zip_ref:
        zip_ref.extractall(train_data_path)

    print("gnn 训练开始")
    model_name = egnn_train(train_data_path, model_save_path, opts.edges_length)
    print("gnn 训练结束")

    # test_path is Optional: skip evaluation instead of failing on None.
    if opts.test_path is None:
        print("未指定测试数据集,跳过测试")
        return 0

    model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model/last.ckpt")
    save_path = predict_and_plot(model_path, opts.test_path.get_full_path(), base_path)
    test_path = save_path["post_process"]

    print("gnn 预测开始")
    egnn_predict_and_plot(model_name, test_path, base_path, opts.edges_length)
    print("gnn 预测结束")
    return 0
def to_parser():
    """Return the mapping from sub-command name to its SubParser."""
    return {
        "原子缺陷预测": SubParser(PredictGlobalOptions, predict_app, "预测"),
        # NOTE(review): "训练" and "原子缺陷模型训练" both dispatch to
        # train_app; the shorter key looks like a superseded name. Confirm
        # whether it can be dropped.
        "训练": SubParser(TrainGlobalOptions, train_app, "训练"),
        # Listed once: the entry was duplicated, and the first copy lacked a
        # trailing comma, so the dict literal did not parse.
        "gnn预测": SubParser(GnnPredictGlobalOptions, gnn_predict_app, "gnn预测"),
        "原子缺陷模型训练": SubParser(TrainGlobalOptions, train_app, "训练"),
        "gnn模型训练": SubParser(GnnTrainGlobalOptions, gnn_train_app, "训练"),
    }
if __name__ == "__main__":

BIN
msunet/gnn_sv_train.zip Normal file

Binary file not shown.