import os
import json
import glob
import sys
import zipfile

import torch
import random
from torch.utils.data import DataLoader, Dataset

torch.set_float32_matmul_precision('high')

import numpy as np
from torch import nn
from torch.utils.data.sampler import *
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from typing import Dict
import pytorch_lightning as pl

from egnn_core.model import PL_EGNN
from egnn_core.data import *
from egnn_core.metrics import *
from egnn_core.aug import *
from torch_geometric.data.lightning import LightningDataset
from torch.utils.data import ConcatDataset
from egnn_utils.save import save_results
from resize import resize_images
from metricse2e_vor import post_process
from model_type_dict import norm_line_sv_label, cz_label
from plot_view import plot2, plot1
from dp.launching.report import Report, ReportSection, AutoReportElement

# constants

DATASETS = '0'
GPUS = 1
BS = 32             # batch size
NW = 1              # number of dataloader workers
WD = 5e-4           # weight decay
LR = 0.01           # learning rate
RF = 0.9            # ReduceLROnPlateau decay factor
EPOCHS = 1000
IN_CHANNELS = 3
SAVE_TOP_K = 5      # checkpoints to keep
EARLY_STOP = 5      # early-stopping patience (epochs)
EVERY_N_EPOCHS = 2  # checkpoint every N epochs

# DATA_PATH = '../../data/linesv/gnn_data/'
# DATA_PATH = '/home/gao/mouclear/cc/data_new/gnn_sv/'
# DATA_PATH = '/home/gao/mouclear/cc/final_todo/SV'
DATA_PATH = '/home/gao/mouclear/cc/final_todo/gnn_sv_train/'

# pytorch lightning module

class PLModel(pl.LightningModule):
    def __init__(self, model_type=None):
        super().__init__()
        self.model = PL_EGNN(model_type=model_type)
        self.model_type = model_type
        if model_type == cz_label:
            print("cz_label:", model_type)
            # two-class task: both classes weighted equally
            self.loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([2., 2.]))
        else:
            print("normal label")
            # three-class task: class 0 gets half the weight of classes 1 and 2
            self.loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([1., 2., 2.]))

    def forward(self, x):
        out = self.model(x.x, x.edge_index, x.batch)
        return out

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=LR, weight_decay=WD)
        scheduler = ReduceLROnPlateau(optimizer, factor=RF, mode='max', patience=2, min_lr=0, verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_acc',
        }

    def training_step(self, train_batch, batch_idx):
        pos, light, y, edge_index, batch, mask = (
            train_batch.pos,
            train_batch.light,
            train_batch.label,
            train_batch.edge_index,
            train_batch.batch,
            train_batch.mask,
        )

        pd = self.model(light, pos, edge_index)
        loss = self.loss(pd[mask], y[mask])
        train_acc = acc(pd[mask], y[mask])
        self.log('train_loss', loss, batch_size=1)
        self.log('train_acc', train_acc, on_epoch=True, prog_bar=True, logger=True, batch_size=1)

        return loss

    def validation_step(self, val_batch, batch_idx):
        pos, light, y, edge_index, batch, mask = (
            val_batch.pos,
            val_batch.light,
            val_batch.label,
            val_batch.edge_index,
            val_batch.batch,
            val_batch.mask,
        )

        pd = self.model(light, pos, edge_index)
        loss = self.loss(pd[mask], y[mask])
        val_acc = acc(pd[mask], y[mask])
        self.log('val_loss', loss, batch_size=1)
        self.log('val_acc', val_acc, on_epoch=True, prog_bar=True, logger=True, batch_size=1)

    def predict_step(self, batch, batch_idx):
        # `batch_vec` is the PyG per-node batch-assignment vector;
        # the incoming `batch` argument is the Data object itself
        pos, light, y, edge_index, batch_vec, mask, name = (
            batch.pos,
            batch.light,
            batch.label,
            batch.edge_index,
            batch.batch,
            batch.mask,
            batch.name,
        )

        pd = self.model(light, pos, edge_index)
        return pd[mask], y[mask], name

# main

def save_predict_result(predictions, save_path):
    preds = torch.squeeze(torch.concat([item[0] for item in predictions])).numpy().tolist()
    labels = torch.squeeze(torch.concat([item[1] for item in predictions])).numpy().tolist()
    name = [item[2] for item in predictions]

    results = {
        'name': name,
        'pred': preds,
        'label': labels,
    }

    results_json = json.dumps(results)
    file_name = os.path.join(save_path, 'predict_result.json')
    with open(file_name, 'w') as f:
        f.write(results_json)

    return file_name
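
# Usage sketch (illustrative): predict_result.json can be read back with the
# standard json module; the keys mirror the `results` dict above, and the exact
# shapes of 'pred'/'label' depend on the model output.
#
# with open(os.path.join(save_path, 'predict_result.json')) as f:
#     results = json.load(f)
# assert set(results) == {'name', 'pred', 'label'}
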
def predict(model_name, test_path, save_path, edges_length=35.5, model_type=None):
    os.makedirs(save_path, exist_ok=True)
    # absolute path of this file
    file_path = os.path.abspath(__file__)
    # directory containing this file
    dir_path = os.path.dirname(file_path)
    print("base path: ", dir_path)
    train_path = os.path.join(dir_path, 'gnn_sv_train/train/')
    valid_path = os.path.join(dir_path, 'gnn_sv_train/valid/')
    print("train path: ", train_path)
    print("valid path: ", valid_path)

    train_dataset = AtomDataset_v2_aug(root=train_path, edges_length=edges_length)
    valid_dataset = AtomDataset_v2_aug(root=valid_path, edges_length=edges_length)
    test_dataset = AtomDataset_v2_aug(root=test_path, edges_length=edges_length, model_type=model_type)
    # e2e_dataset = AtomDataset_v2_aug(root=test_path)

    datamodule = LightningDataset(
        train_dataset,
        valid_dataset,
        test_dataset,
        None,
        batch_size=BS,
        num_workers=NW,
    )

    model = PLModel(model_type=model_type)

    logger = TensorBoardLogger(
        name=DATASETS,
        save_dir='logs',
    )

    checkpoint_callback = ModelCheckpoint(
        every_n_epochs=EVERY_N_EPOCHS,
        save_top_k=SAVE_TOP_K,
        monitor='val_acc',
        mode='max',
        save_last=True,
        filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}'
    )

    earlystop_callback = EarlyStopping(
        monitor="val_acc",
        mode="max",
        min_delta=0.00,
        patience=EARLY_STOP,
    )

    # training
    trainer = pl.Trainer(
        log_every_n_steps=1,
        devices=GPUS,
        max_epochs=EPOCHS,
        logger=logger,
        callbacks=[earlystop_callback, checkpoint_callback],
    )

    # trainer.fit(
    #     model,
    #     datamodule,
    # )

    # inference test (Trainer.predict expects a dataloader, not a raw dataset)
    predictions = trainer.predict(
        model,
        datamodule.test_dataloader(),
        ckpt_path=model_name,
    )

    return save_predict_result(predictions, save_path)

def create_save_path(base_path: str, model_type: str) -> Dict[str, str]:
    # directory layout for the prediction pipeline outputs
    paths = {
        "gnn_predict_dataset": os.path.join(base_path, model_type, "gnn_predict_dataset"),
        "gnn_predict_result": os.path.join(base_path, model_type, 'gnn_predict_json'),
        "gnn_predict_post_process": os.path.join(base_path, model_type, "gnn_predict_post_process"),
        "gnn_predict_connect_view": os.path.join(base_path, model_type, "gnn_predict_connect_view"),
        "gnn_predict_result_view": os.path.join(base_path, model_type, "gnn_predict_result_view"),
    }

    # create the directories
    for path in paths.values():
        os.makedirs(path, exist_ok=True)

    return paths
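
# Usage sketch (illustrative): the returned dict maps each pipeline stage to a
# freshly created directory. The base path and model_type value here are
# placeholders, not values from this project.
#
# paths = create_save_path('./runs', 'sv')
# print(paths['gnn_predict_result'])   # -> ./runs/sv/gnn_predict_json
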
# def zip_dir(dir_path: str, zip_path: str) -> str:
#     with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
#         for root, _, files in os.walk(dir_path):
#             for fname in files:
#                 full = os.path.join(root, fname)
#                 zf.write(full, os.path.relpath(full, dir_path))
#     return zip_path

def gnn_generate_report(save_path: Dict[str, str], output_dir: str) -> None:
    img_elements = []
    # iterate over the original images under output_dir/predict_dataset
    for img_path in glob.glob(os.path.join(output_dir, 'predict_dataset', "**/*.jpg"), recursive=True):
        img_elements.append(AutoReportElement(
            path=os.path.relpath(img_path, output_dir),
            title=os.path.basename(img_path),
            description='original image',
        ))

    ori_img_section = ReportSection(title="Original Images", ncols=2, elements=img_elements)

    img_elements = []
    # iterate over the rendered prediction views
    for img_path in glob.glob(save_path["gnn_predict_result_view"] + "/*.jpg"):
        print("gnn_predict_result_view, predict image path:", img_path)
        img_elements.append(AutoReportElement(
            path=os.path.relpath(img_path, output_dir),
            title=os.path.basename(img_path),
            description='prediction result',
        ))

    post_process_img_section = ReportSection(title="Prediction Results", ncols=2, elements=img_elements)

    img_elements = []
    # iterate over the connection-line views
    for img_path in glob.glob(save_path["gnn_predict_connect_view"] + "/*.jpg"):
        print("gnn_predict_connect_view, connect image path:", img_path)
        img_elements.append(AutoReportElement(
            path=os.path.relpath(img_path, output_dir),
            title=os.path.basename(img_path),
            description='connection result',
        ))

    connect_process_img_section = ReportSection(title="Connection Results", ncols=2, elements=img_elements)

    report = Report(title="Prediction Results",
                    sections=[post_process_img_section, connect_process_img_section, ori_img_section])
    report.save(output_dir)

def plot_connect_line(json_folder: str, output_folder: str):
    for json_file in glob.glob(os.path.join(json_folder, '*.json')):
        print("json_file:", json_file)
        plot1(json_file, output_folder)

def predict_and_plot(model_name, test_path, save_base_path, edges_length=35.5, model_type=norm_line_sv_label):
    print("gnn model: ", model_name)
    # create the output directories
    save_path = create_save_path(save_base_path, model_type)
    # preprocessing: resize the input images
    resize_images(512, 256, test_path)
    # run prediction
    predict_result = predict(model_name, test_path, save_path['gnn_predict_result'],
                             edges_length=edges_length, model_type=model_type)
    # post-process the predictions
    post_process(predict_result, test_path, save_path['gnn_predict_post_process'], model_type)
    # draw the connection lines
    plot_connect_line(save_path['gnn_predict_post_process'], save_path['gnn_predict_connect_view'])
    # render the prediction views
    plot2(img_folder=save_path['gnn_predict_post_process'], json_folder=save_path['gnn_predict_post_process'],
          output_folder=save_path['gnn_predict_result_view'])
    # report generation is left to the caller (see gnn_generate_report)
    return save_path

# if __name__ == '__main__':
#     model_name = './model/cz.ckpt'
#     # test_path = './gnn_sv_train/test/'
#     test_path = './predict_post_process/'
#     resize_images(512, 256, test_path)
#     save_path = './predict_result666/predict_json'
#
#     model_type = cz_label
#     predict_result = predict(model_name, test_path, save_path, model_type=model_type)
#
#     post_process_save_path = './predict_result666/post_process/'
#     post_process(predict_result, test_path, post_process_save_path, model_type=model_type)
#     plot_save_path = './predict_result666/predict_result_view/'
#     os.makedirs(plot_save_path, exist_ok=True)
#     plot2(img_folder=post_process_save_path, json_folder=post_process_save_path, output_folder=plot_save_path)

# if __name__ == '__main__':
#     train_dataset = AtomDataset_v2_aug(root='./gnn_sv_train/train/')
#     valid_dataset = AtomDataset_v2_aug(root='{}/valid/'.format(DATA_PATH))
#     test_dataset = AtomDataset_v2_aug(root='./gnn_sv_train/test/')
#     e2e_dataset = AtomDataset_v2_aug(root='./gnn_sv_train/test/')
#
#     datamodule = LightningDataset(
#         train_dataset,
#         valid_dataset,
#         test_dataset,
#         e2e_dataset,
#         batch_size=BS,
#         num_workers=NW,
#     )
#
#     model = PLModel()
#
#     logger = TensorBoardLogger(
#         name=DATASETS,
#         save_dir='logs',
#     )
#
#     checkpoint_callback = ModelCheckpoint(
#         every_n_epochs=EVERY_N_EPOCHS,
#         save_top_k=SAVE_TOP_K,
#         monitor='val_acc',
#         mode='max',
#         save_last=True,
#         filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}'
#     )
#
#     earlystop_callback = EarlyStopping(
#         monitor="val_acc",
#         mode="max",
#         min_delta=0.00,
#         patience=EARLY_STOP,
#     )
#
#     # training
#     trainer = pl.Trainer(
#         log_every_n_steps=1,
#         devices=GPUS,
#         max_epochs=EPOCHS,
#         logger=logger,
#         callbacks=[earlystop_callback, checkpoint_callback],
#     )
#
#     # trainer.fit(
#     #     model,
#     #     datamodule,
#     # )
#
#     # inference test
#     predictions = trainer.predict(
#         model,
#         datamodule.test_dataloader(),
#         ckpt_path='./logs/0/version_6/checkpoints/last.ckpt',
#     )
#     save_results(trainer.log_dir, predictions, 'test')

# inference e2e
# predictions = trainer.predict(
#     model,
#     datamodule.predict_dataloader(),
#     ckpt_path='best',
# )
#
# save_results(trainer.log_dir, predictions, 'e2e')
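
# Sketch of a complete entry point: run the full pipeline and emit the report.
# The checkpoint, input, and output paths below are placeholders.
#
# if __name__ == '__main__':
#     model_name = './model/sv.ckpt'      # placeholder checkpoint
#     test_path = './gnn_sv_predict/'     # placeholder input directory
#     output_dir = './predict_output/'    # placeholder output root
#
#     save_path = predict_and_plot(model_name, test_path, output_dir,
#                                  model_type=norm_line_sv_label)
#     gnn_generate_report(save_path, output_dir)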