import os
import json
import glob
import shutil
import random

import torch
import numpy as np
from torch import nn
from torch.utils.data.sampler import *
from torch.utils.data import ConcatDataset
from torch.optim.lr_scheduler import ReduceLROnPlateau

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch_geometric.data.lightning import LightningDataset

from egnn_core.model import PL_EGNN
from egnn_core.data import *
from egnn_core.metrics import *
from egnn_core.aug import *
from egnn_utils.save import save_results
from resize import resize_images

# Trade a little float32 matmul precision for TensorCore speed.
torch.set_float32_matmul_precision('high')

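# Note: the wildcard imports keep this script short; AtomDataset_v2_aug and
# acc used below are expected to come from egnn_core.data and
# egnn_core.metrics respectively.
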
# constants

DATASETS = '0'       # run name used by the TensorBoard logger
GPUS = 1             # number of devices for the Trainer
BS = 32              # batch size (graphs per batch)
NW = 1               # dataloader worker processes
WD = 5e-4            # AdamW weight decay
LR = 0.01            # initial learning rate
RF = 0.9             # ReduceLROnPlateau decay factor
EPOCHS = 1000        # upper bound; early stopping usually ends training first
IN_CHANNELS = 3      # input feature channels (not referenced in this script)
SAVE_TOP_K = 5       # keep the 5 best checkpoints by val_acc
EARLY_STOP = 5       # early-stopping patience, in epochs
EVERY_N_EPOCHS = 2   # checkpoint frequency, in epochs

DATA_PATH = './gnn_sv_train/'


# pytorch lightning module
class PLModel(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = PL_EGNN()
        # Class-weighted cross entropy: class 0 weighted 1, classes 1 and 2
        # weighted 2, to counter class imbalance.
        self.loss = nn.CrossEntropyLoss(weight=torch.FloatTensor([1., 2., 2.]))

    def forward(self, x):
        # Mirror the call used in the step methods: (light, pos, edge_index).
        return self.model(x.light, x.pos, x.edge_index)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=LR, weight_decay=WD)
        scheduler = ReduceLROnPlateau(optimizer, factor=RF, mode='max', patience=2, min_lr=0)
        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'monitor': 'val_acc',
            },
        }
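
    # Note on the schedule: ReduceLROnPlateau reacts after 2 stagnant epochs
    # while EarlyStopping (configured in train() below) waits for
    # EARLY_STOP = 5, so the LR is typically decayed a couple of times
    # before a run is terminated.
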
    def training_step(self, train_batch, batch_idx):
        pos, light, y, edge_index, batch, mask = (
            train_batch.pos,
            train_batch.light,
            train_batch.label,
            train_batch.edge_index,
            train_batch.batch,
            train_batch.mask,
        )

        # Loss and accuracy are computed only on the masked (labelled) nodes.
        pd = self.model(light, pos, edge_index)
        loss = self.loss(pd[mask], y[mask])
        train_acc = acc(pd[mask], y[mask])
        self.log('train_loss', loss, batch_size=1)
        self.log('train_acc', train_acc, on_epoch=True, prog_bar=True, logger=True, batch_size=1)

        return loss

    def validation_step(self, val_batch, batch_idx):
        pos, light, y, edge_index, batch, mask = (
            val_batch.pos,
            val_batch.light,
            val_batch.label,
            val_batch.edge_index,
            val_batch.batch,
            val_batch.mask,
        )

        pd = self.model(light, pos, edge_index)
        loss = self.loss(pd[mask], y[mask])
        val_acc = acc(pd[mask], y[mask])
        self.log('val_loss', loss, batch_size=1)
        self.log('val_acc', val_acc, on_epoch=True, prog_bar=True, logger=True, batch_size=1)

    def predict_step(self, batch, batch_idx):
        # Unpack into `batch_vec` rather than rebinding `batch`, so the
        # Data object itself stays accessible throughout the method.
        pos, light, y, edge_index, batch_vec, mask, name = (
            batch.pos,
            batch.light,
            batch.label,
            batch.edge_index,
            batch.batch,
            batch.mask,
            batch.name,
        )

        pd = self.model(light, pos, edge_index)
        return pd[mask], y[mask], name
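
# A minimal smoke test for the module, kept as a comment. Shapes are
# hypothetical: `light` is assumed to carry IN_CHANNELS features per node,
# but PL_EGNN's true input sizes are defined in egnn_core.model.
#
#     from torch_geometric.data import Data
#     toy = Data(
#         pos=torch.rand(4, 2),
#         light=torch.rand(4, IN_CHANNELS),
#         edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]]),
#         label=torch.zeros(4, dtype=torch.long),
#         mask=torch.ones(4, dtype=torch.bool),
#     )
#     logits = PLModel().model(toy.light, toy.pos, toy.edge_index)
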
# main

def train(dataset_path, model_save_path, edges_length=35.5):
    os.makedirs(model_save_path, exist_ok=True)
    train_dataset = AtomDataset_v2_aug(root='{}/train/'.format(dataset_path), edges_length=edges_length)
    valid_dataset = AtomDataset_v2_aug(root='{}/valid/'.format(dataset_path), edges_length=edges_length)

    # Test and prediction datasets are omitted; pass them as the third and
    # fourth positional arguments below to enable trainer.predict / test.
    datamodule = LightningDataset(
        train_dataset,
        valid_dataset,
        None,
        None,
        batch_size=BS,
        num_workers=NW,
    )

    model = PLModel()

    logger = TensorBoardLogger(
        name=DATASETS,
        save_dir='logs',
    )

    checkpoint_callback = ModelCheckpoint(
        every_n_epochs=EVERY_N_EPOCHS,
        save_top_k=SAVE_TOP_K,
        monitor='val_acc',
        mode='max',
        save_last=True,
        filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}',
    )

    earlystop_callback = EarlyStopping(
        monitor='val_acc',
        mode='max',
        min_delta=0.00,
        patience=EARLY_STOP,
    )

    # training
    trainer = pl.Trainer(
        log_every_n_steps=1,
        devices=GPUS,
        max_epochs=EPOCHS,
        logger=logger,
        callbacks=[earlystop_callback, checkpoint_callback],
    )

    trainer.fit(
        model,
        datamodule,
    )

    # Copy the last checkpoint written by the callback to a stable path;
    # return None if the callback never saved anything.
    model_name = None
    last_checkpoint_path = checkpoint_callback.last_model_path

    if last_checkpoint_path:
        model_name = os.path.join(model_save_path, 'last.ckpt')
        shutil.copy(last_checkpoint_path, model_name)
    else:
        print("No checkpoint was saved by the checkpoint callback.")

    return model_name

# trainer.save_checkpoint(model_name)

# # inference test
# predictions = trainer.predict(
#     model,
#     datamodule.test_dataset,
#     ckpt_path='best',
# )
# save_results(trainer.log_dir, predictions, 'test')

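# Reloading the exported checkpoint later (a sketch; relies only on the
# standard LightningModule checkpoint format written by trainer.fit above):
#
#     model = PLModel.load_from_checkpoint('./train_result/train_model/last.ckpt')
#     model.eval()
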
if __name__ == '__main__':
    dataset_path = './gnn_sv_train/'
    model_save_path = './train_result/train_model/'
    train(dataset_path, model_save_path)