# atom-predict/egnn_v2/predict_pl_v2_aug.py

import os
import json
import glob
import sys
import zipfile
import torch
import random
from torch.utils.data import DataLoader, Dataset
torch.set_float32_matmul_precision('high')
import numpy as np
from torch import nn
from torch.utils.data.sampler import *
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from typing import Dict
import pytorch_lightning as pl
from egnn_core.model import PL_EGNN
from egnn_core.data import *
from egnn_core.metrics import *
from egnn_core.aug import *
from torch_geometric.data.lightning import LightningDataset
from torch.utils.data import ConcatDataset
from egnn_utils.save import save_results
from resize import resize_images
from metricse2e_vor import post_process
from model_type_dict import norm_line_sv_label, cz_label
from plot_view import plot2, plot1
from dp.launching.report import Report, ReportSection, AutoReportElement
# constants
DATASETS = '0'          # TensorBoard run name
GPUS = 1                # number of devices for the Trainer
BS = 32                 # batch size
NW = 1                  # dataloader workers
WD = 5e-4               # weight decay
LR = 0.01               # learning rate
RF = 0.9                # ReduceLROnPlateau factor
EPOCHS = 1000           # max training epochs
IN_CHANNELS = 3
SAVE_TOP_K = 5          # number of best checkpoints to keep
EARLY_STOP = 5          # early-stopping patience (epochs)
EVERY_N_EPOCHES = 2     # checkpoint frequency
# DATA_PATH = '../../data/linesv/gnn_data/'
# DATA_PATH = '/home/gao/mouclear/cc/data_new/gnn_sv/'
# DATA_PATH = '/home/gao/mouclear/cc/final_todo/SV'
DATA_PATH = '/home/gao/mouclear/cc/final_todo/gnn_sv_train/'
# pytorch lightning module
class PLModel(pl.LightningModule):
    def __init__(self, model_type=None):
        super().__init__()
        self.model = PL_EGNN(model_type=model_type)
        self.model_type = model_type
        if model_type == cz_label:
            print("cz_label:", model_type)
            # two-class weighted cross-entropy for cz-type labels
            self.loss = nn.CrossEntropyLoss(torch.FloatTensor([2., 2.]))
        else:
            print("normal label")
            # three-class weighted cross-entropy for the default labels
            self.loss = nn.CrossEntropyLoss(torch.FloatTensor([1., 2., 2.]))
    def forward(self, x):
        # match the (light, pos, edge_index) calling convention used by the
        # step methods below
        out = self.model(x.light, x.pos, x.edge_index)
        return out
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=LR, weight_decay=WD)
        scheduler = ReduceLROnPlateau(optimizer, factor=RF, mode='max', patience=2, min_lr=0, verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_acc',
        }
    def training_step(self, train_batch, batch_idx):
        pos, light, y, edge_index, batch, mask = (
            train_batch.pos,
            train_batch.light,
            train_batch.label,
            train_batch.edge_index,
            train_batch.batch,
            train_batch.mask,
        )
        pd = self.model(light, pos, edge_index)
        loss = self.loss(pd[mask], y[mask])
        train_acc = acc(pd[mask], y[mask])
        self.log('train_loss', loss, batch_size=1)
        self.log('train_acc', train_acc, on_epoch=True, prog_bar=True, logger=True, batch_size=1)
        return loss
    def validation_step(self, val_batch, batch_idx):
        pos, light, y, edge_index, batch, mask = (
            val_batch.pos,
            val_batch.light,
            val_batch.label,
            val_batch.edge_index,
            val_batch.batch,
            val_batch.mask,
        )
        pd = self.model(light, pos, edge_index)
        loss = self.loss(pd[mask], y[mask])
        val_acc = acc(pd[mask], y[mask])
        self.log('val_loss', loss, batch_size=1)
        self.log('val_acc', val_acc, on_epoch=True, prog_bar=True, logger=True, batch_size=1)
    def predict_step(self, batch, batch_idx):
        # unpack into a new name so the incoming `batch` is not shadowed
        pos, light, y, edge_index, batch_vec, mask, name = (
            batch.pos,
            batch.light,
            batch.label,
            batch.edge_index,
            batch.batch,
            batch.mask,
            batch.name,
        )
        pd = self.model(light, pos, edge_index)
        return pd[mask], y[mask], name
# main
def save_predict_result(predictions, save_path):
    preds = torch.squeeze(torch.concat([item[0] for item in predictions])).numpy().tolist()
    labels = torch.squeeze(torch.concat([item[1] for item in predictions])).numpy().tolist()
    name = [item[2] for item in predictions]
    results = {
        'name': name,
        'pred': preds,
        'label': labels,
    }
    results_json = json.dumps(results)
    file_name = os.path.join(save_path, 'predict_result.json')
    with open(file_name, 'w') as f:
        f.write(results_json)
    return file_name
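# Illustrative shape of predict_result.json (values are made up; the number of
# logits per node depends on model_type):
# {
#   "name":  [["img_001", ...], ...],      # one entry per predict batch
#   "pred":  [[0.12, 1.70, -0.33], ...],   # per-node class logits
#   "label": [1, 0, 2, ...]                # per-node ground-truth classes
# }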
def predict(model_name, test_path, save_path, edges_length=35.5, model_type=None):
    os.makedirs(save_path, exist_ok=True)
    # absolute path of this file
    file_path = os.path.abspath(__file__)
    # directory containing this file
    dir_path = os.path.dirname(file_path)
    print("base path: ", dir_path)
    # the validation split lives under gnn_sv_train/ alongside train/
    train_path = os.path.join(dir_path, 'gnn_sv_train/train/')
    valid_path = os.path.join(dir_path, 'gnn_sv_train/valid/')
    print("train path: ", train_path)
    print("valid_path: ", valid_path)
    train_dataset = AtomDataset_v2_aug(root=train_path, edges_length=edges_length)
    valid_dataset = AtomDataset_v2_aug(root=valid_path, edges_length=edges_length)
    test_dataset = AtomDataset_v2_aug(root=test_path, edges_length=edges_length, model_type=model_type)
    # e2e_dataset = AtomDataset_v2_aug(root=test_path)
    datamodule = LightningDataset(
        train_dataset,
        valid_dataset,
        test_dataset,
        None,
        batch_size=BS,
        num_workers=NW,
    )
    model = PLModel(model_type=model_type)
    logger = TensorBoardLogger(
        name=DATASETS,
        save_dir='logs',
    )
    checkpoint_callback = ModelCheckpoint(
        every_n_epochs=EVERY_N_EPOCHES,
        save_top_k=SAVE_TOP_K,
        monitor='val_acc',
        mode='max',
        save_last=True,
        filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}',
    )
    earlystop_callback = EarlyStopping(
        monitor="val_acc",
        mode="max",
        min_delta=0.00,
        patience=EARLY_STOP,
    )
    # training
    trainer = pl.Trainer(
        log_every_n_steps=1,
        devices=GPUS,
        max_epochs=EPOCHS,
        logger=logger,
        callbacks=[earlystop_callback, checkpoint_callback],
    )
    # trainer.fit(
    #     model,
    #     datamodule,
    # )
    # inference test
    predictions = trainer.predict(
        model,
        datamodule.test_dataset,
        ckpt_path=model_name,
    )
    return save_predict_result(predictions, save_path)
def create_save_path(base_path: str, model_type: str) -> Dict[str, str]:
    # directory layout to create under base_path/model_type
    paths = {
        "gnn_predict_dataset": os.path.join(base_path, model_type, "gnn_predict_dataset"),
        "gnn_predict_result": os.path.join(base_path, model_type, 'gnn_predict_json'),
        "gnn_predict_post_process": os.path.join(base_path, model_type, "gnn_predict_post_process"),
        "gnn_predict_connect_view": os.path.join(base_path, model_type, "gnn_predict_connect_view"),
        "gnn_predict_result_view": os.path.join(base_path, model_type, "gnn_predict_result_view"),
    }
    # create the directories
    for path in paths.values():
        os.makedirs(path, exist_ok=True)
    return paths
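# Resulting layout (illustrative; 'out' and 'sv' are placeholder arguments),
# e.g. for create_save_path('out', 'sv'):
#   out/sv/gnn_predict_dataset/
#   out/sv/gnn_predict_json/
#   out/sv/gnn_predict_post_process/
#   out/sv/gnn_predict_connect_view/
#   out/sv/gnn_predict_result_view/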
def gnn_generate_report(save_path: Dict[str, str], output_dir: str) -> None:
    img_elements = []
    # original images are collected from output_dir/predict_dataset
    for img_path in glob.glob(os.path.join(output_dir, 'predict_dataset', "**/*.jpg"), recursive=True):
        img_elements.append(AutoReportElement(
            path=os.path.relpath(img_path, output_dir),
            title=img_path.split("/")[-1],
            description='Original image',
        ))
    ori_img_section = ReportSection(title="Original images", ncols=2, elements=img_elements)
    img_elements = []
    # post-processed prediction views
    for img_path in glob.glob(save_path["gnn_predict_result_view"] + "/*.jpg"):
        print("gnn_predict_result_view, predict image path:", img_path)
        img_elements.append(AutoReportElement(
            path=os.path.relpath(img_path, output_dir),
            title=img_path.split('/')[-1],
            description='Prediction result',
        ))
    post_process_img_section = ReportSection(title="Prediction results", ncols=2, elements=img_elements)
    img_elements = []
    # connectivity views of the predictions
    for img_path in glob.glob(save_path["gnn_predict_connect_view"] + "/*.jpg"):
        print("gnn_predict_connect_view, connect image path:", img_path)
        img_elements.append(AutoReportElement(
            path=os.path.relpath(img_path, output_dir),
            title=img_path.split('/')[-1],
            description='Predicted connections',
        ))
    connect_process_img_section = ReportSection(title="Predicted connections", ncols=2, elements=img_elements)
    report = Report(title="Prediction report", sections=[post_process_img_section, connect_process_img_section, ori_img_section])
    report.save(output_dir)
def plot_connect_line(json_folder: str, output_folder: str):
    for json_file in glob.glob(os.path.join(json_folder, '*.json')):
        print("json_file:", json_file)
        plot1(json_file, output_folder)
def predict_and_plot(model_name, test_path, save_base_path, edges_length=35.5, model_type=norm_line_sv_label):
    print("gnn model: ", model_name)
    # create the output directory tree
    save_path = create_save_path(save_base_path, model_type)
    # preprocessing: resize the input images
    resize_images(512, 256, test_path)
    # run prediction
    predict_result = predict(model_name, test_path, save_path['gnn_predict_result'], edges_length=edges_length, model_type=model_type)
    # post-process the raw predictions
    post_process(predict_result, test_path, save_path['gnn_predict_post_process'], model_type)
    # draw the connection lines
    plot_connect_line(save_path['gnn_predict_post_process'], save_path['gnn_predict_connect_view'])
    # draw the prediction views
    plot2(img_folder=save_path['gnn_predict_post_process'], json_folder=save_path['gnn_predict_post_process'],
          output_folder=save_path['gnn_predict_result_view'])
    # the caller can build a report from the returned paths (see gnn_generate_report)
    return save_path
# if __name__ == '__main__':
#     model_name = './model/cz.ckpt'
#     # test_path = './gnn_sv_train/test/'
#     test_path = './predict_post_process/'
#     resize_images(512, 256, test_path)
#     save_path = './predict_result666/predict_json'
#
#     model_type = cz_label
#     predict_result = predict(model_name, test_path, save_path, model_type=model_type)
#
#     post_process_save_path = './predict_result666/post_process/'
#     post_process(predict_result, test_path, post_process_save_path, model_type=model_type)
#     plot_save_path = './predict_result666/predict_result_view/'
#     os.makedirs(plot_save_path, exist_ok=True)
#     plot2(img_folder=post_process_save_path, json_folder=post_process_save_path, output_folder=plot_save_path)
# if __name__ == '__main__':
#     train_dataset = AtomDataset_v2_aug(root='./gnn_sv_train/train/')
#     valid_dataset = AtomDataset_v2_aug(root='{}/valid/'.format(DATA_PATH))
#     test_dataset = AtomDataset_v2_aug(root='./gnn_sv_train/test/')
#     e2e_dataset = AtomDataset_v2_aug(root='./gnn_sv_train/test/')
#
#     datamodule = LightningDataset(
#         train_dataset,
#         valid_dataset,
#         test_dataset,
#         e2e_dataset,
#         batch_size=BS,
#         num_workers=NW,
#     )
#
#     model = PLModel()
#
#     logger = TensorBoardLogger(
#         name=DATASETS,
#         save_dir='logs',
#     )
#
#     checkpoint_callback = ModelCheckpoint(
#         every_n_epochs=EVERY_N_EPOCHES,
#         save_top_k=SAVE_TOP_K,
#         monitor='val_acc',
#         mode='max',
#         save_last=True,
#         filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}'
#     )
#
#     earlystop_callback = EarlyStopping(
#         monitor="val_acc",
#         mode="max",
#         min_delta=0.00,
#         patience=EARLY_STOP,
#     )
#
#     # training
#     trainer = pl.Trainer(
#         log_every_n_steps=1,
#         devices=GPUS,
#         max_epochs=EPOCHS,
#         logger=logger,
#         callbacks=[earlystop_callback, checkpoint_callback],
#     )
#
#     # trainer.fit(
#     #     model,
#     #     datamodule,
#     # )
#
#     # inference test
#     predictions = trainer.predict(
#         model,
#         datamodule.test_dataset,
#         ckpt_path='./logs/0/version_6/checkpoints/last.ckpt',
#     )
#     save_results(trainer.log_dir, predictions, 'test')
#
#     # inference e2e
#     predictions = trainer.predict(
#         model,
#         datamodule.pred_dataset,
#         ckpt_path='best',
#     )
#
#     save_results(trainer.log_dir, predictions, 'e2e')
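# Minimal usage sketch for the live entry point above. The checkpoint and
# directory names are placeholders (assumptions), not paths shipped with the
# repo; adjust them to your layout.
if __name__ == '__main__':
    # hypothetical checkpoint and input/output locations
    gnn_model = './model/sv.ckpt'
    input_images = './predict_dataset/'
    output_base = './predict_result/'
    paths = predict_and_plot(gnn_model, input_images, output_base,
                             edges_length=35.5, model_type=norm_line_sv_label)
    # optionally render a report over the generated views
    gnn_generate_report(paths, output_base)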