# atom-predict/egnn_v2/train_pl_v2_aug.py
#
# (File-viewer metadata captured with this copy: 268 lines, 7.4 KiB, Python.)

import os
import json
import glob
import shutil
import torch
import random
torch.set_float32_matmul_precision('high')
import numpy as np
from torch import nn
from torch.utils.data.sampler import *
from torch.optim.lr_scheduler import ReduceLROnPlateau
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import pytorch_lightning as pl
from egnn_core.model import PL_EGNN
from egnn_core.data import *
from egnn_core.metrics import *
from egnn_core.aug import *
from torch_geometric.data.lightning import LightningDataset
from torch.utils.data import ConcatDataset
from egnn_utils.save import save_results
from resize import resize_images
# ---------------------------------------------------------------------------
# Training configuration: module-level constants read by PLModel and train().
# ---------------------------------------------------------------------------

# Run name; used as the TensorBoardLogger ``name`` under ``save_dir='logs'``.
DATASETS: str = '0'

# Hardware and data loading.
GPUS: int = 1            # passed to ``pl.Trainer(devices=...)``
BS: int = 32             # batch size handed to LightningDataset
NW: int = 1              # number of DataLoader worker processes

# Optimisation (AdamW + ReduceLROnPlateau).
LR: float = 0.01         # AdamW learning rate
WD: float = 5e-4         # AdamW weight decay
RF: float = 0.9          # ReduceLROnPlateau reduction factor
EPOCHS: int = 1000       # upper bound on training epochs

# NOTE(review): not referenced by the live code in this file -- confirm it is still needed.
IN_CHANNELS: int = 3

# Checkpointing and early stopping (both track ``val_acc``, mode 'max').
SAVE_TOP_K: int = 5      # number of best checkpoints kept by ModelCheckpoint
EARLY_STOP: int = 5      # EarlyStopping patience
EVERY_N_EPOCHES: int = 2  # ModelCheckpoint frequency; spelling kept for existing callers

# Default dataset root. Previously used locations:
#   '../../data/linesv/gnn_data/'
#   '/home/gao/mouclear/cc/data_new/gnn_sv/'
#   '/home/gao/mouclear/cc/final_todo/SV'
#   '/home/gao/mouclear/cc/final_todo/gnn_sv_train/'
DATA_PATH: str = './gnn_sv_train/'
# pytorch lightning module
class PLModel(pl.LightningModule):
    """LightningModule wrapping ``PL_EGNN`` for masked node classification.

    Every step feeds ``(light, pos, edge_index)`` from a PyG batch to the
    EGNN. Loss and accuracy are computed only on the entries selected by
    the batch's ``mask`` attribute. The loss is a class-weighted
    cross-entropy over three classes (weights ``[1, 2, 2]``).
    """

    def __init__(self):
        super().__init__()
        self.model = PL_EGNN()
        # Three classes; classes 1 and 2 are weighted twice as heavily as class 0.
        self.loss = nn.CrossEntropyLoss(
            weight=torch.tensor([1.0, 2.0, 2.0], dtype=torch.float32)
        )

    def forward(self, x):
        """Run the EGNN on a PyG batch and return its raw outputs.

        Uses the same ``(light, pos, edge_index)`` inputs as the
        training/validation/prediction steps. The previous version passed
        ``(x.x, x.edge_index, x.batch)``, which did not match those call
        sites.
        """
        return self.model(x.light, x.pos, x.edge_index)

    def _forward_masked(self, data):
        """Return ``(outputs, labels)`` restricted to the entries selected by ``data.mask``."""
        out = self.model(data.light, data.pos, data.edge_index)
        mask = data.mask
        return out[mask], data.label[mask]

    def configure_optimizers(self):
        """AdamW with ReduceLROnPlateau driven by ``val_acc`` (maximised)."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=LR, weight_decay=WD)
        scheduler = ReduceLROnPlateau(
            optimizer, factor=RF, mode='max', patience=2, min_lr=0, verbose=True
        )
        return {
            'optimizer': optimizer,
            'lr_scheduler': scheduler,
            'monitor': 'val_acc',
        }

    def training_step(self, train_batch, batch_idx):
        """Compute and log the masked weighted cross-entropy and accuracy."""
        logits, target = self._forward_masked(train_batch)
        loss = self.loss(logits, target)
        train_acc = acc(logits, target)
        # batch_size=1: every step counts equally in the epoch average
        # (presumably deliberate, since the masked node count varies per batch).
        self.log('train_loss', loss, batch_size=1)
        self.log('train_acc', train_acc, on_epoch=True, prog_bar=True, logger=True, batch_size=1)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log masked validation loss and accuracy (``val_acc`` drives callbacks)."""
        logits, target = self._forward_masked(val_batch)
        loss = self.loss(logits, target)
        val_acc = acc(logits, target)
        self.log('val_loss', loss, batch_size=1)
        self.log('val_acc', val_acc, on_epoch=True, prog_bar=True, logger=True, batch_size=1)

    def predict_step(self, batch, batch_idx):
        """Return ``(masked outputs, masked labels, batch.name)`` for one batch."""
        logits, target = self._forward_masked(batch)
        return logits, target, batch.name
# main
def train(dataset_path, model_save_path, edges_length=35.5):
    """Train a ``PLModel`` and copy its final checkpoint into ``model_save_path``.

    Args:
        dataset_path: Directory containing ``train/`` and ``valid/`` roots for
            ``AtomDataset_v2_aug``.
        model_save_path: Output directory (created if missing). The last
            checkpoint written by ``ModelCheckpoint`` is copied there as
            ``last.ckpt``.
        edges_length: Forwarded unchanged to ``AtomDataset_v2_aug`` for both
            splits (presumably the edge-construction cutoff -- confirm).

    Returns:
        Path to the copied ``last.ckpt``, or ``None`` if the checkpoint
        callback never saved a checkpoint. A message is printed in that case.
        The previous version crashed with ``UnboundLocalError`` here.
    """
    os.makedirs(model_save_path, exist_ok=True)

    train_dataset = AtomDataset_v2_aug(root='{}/train/'.format(dataset_path), edges_length=edges_length)
    valid_dataset = AtomDataset_v2_aug(root='{}/valid/'.format(dataset_path), edges_length=edges_length)

    # Only train/valid splits are used; test and predict datasets are None.
    datamodule = LightningDataset(
        train_dataset,
        valid_dataset,
        None,
        None,
        batch_size=BS,
        num_workers=NW,
    )

    model = PLModel()

    # Logs (and checkpoints) land under ./logs/<DATASETS>/version_N/.
    logger = TensorBoardLogger(
        name=DATASETS,
        save_dir='logs',
    )

    checkpoint_callback = ModelCheckpoint(
        every_n_epochs=EVERY_N_EPOCHES,
        save_top_k=SAVE_TOP_K,
        monitor='val_acc',
        mode='max',
        save_last=True,
        filename='{epoch}-{val_loss:.2f}-{val_acc:.2f}'
    )
    earlystop_callback = EarlyStopping(
        monitor="val_acc",
        mode="max",
        min_delta=0.00,
        patience=EARLY_STOP,
    )

    trainer = pl.Trainer(
        log_every_n_steps=1,
        devices=GPUS,
        max_epochs=EPOCHS,
        logger=logger,
        callbacks=[earlystop_callback, checkpoint_callback],
    )
    trainer.fit(model, datamodule=datamodule)

    # last_model_path is empty when the callback never wrote a checkpoint.
    last_checkpoint_path = checkpoint_callback.last_model_path
    if not last_checkpoint_path:
        print("No checkpoint was saved by the checkpoint callback.")
        return None

    model_name = os.path.join(model_save_path, 'last.ckpt')
    shutil.copy(last_checkpoint_path, model_name)
    return model_name
#trainer.save_checkpoint(model_name)
# # inference test
# predictions = trainer.predict(
# model,
# datamodule.test_dataset,
# ckpt_path='best',
# )
# save_results(trainer.log_dir, predictions, 'test')
# if __name__ == '__main__':
# dataset_path = './gnn_sv_train/'
# model_save_path = './train_result/train_model/'
# train(dataset_path, model_save_path)
# if __name__ == '__main__':
# train_dataset = AtomDataset_v2_aug(root='{}/train/'.format(DATA_PATH))
# valid_dataset = AtomDataset_v2_aug(root='{}/valid/'.format(DATA_PATH))
# test_dataset = AtomDataset_v2_aug(root='{}/test/'.format(DATA_PATH))
# e2e_dataset = AtomDataset_v2_aug(root='{}/test/'.format(DATA_PATH))
#
# datamodule = LightningDataset(
# train_dataset,
# valid_dataset,
# test_dataset,
# e2e_dataset,
# batch_size=BS,
# num_workers=NW,
# )
#
# model = PLModel()
#
# logger = TensorBoardLogger(
# name = DATASETS,
# save_dir = 'logs',
# )
#
# checkpoint_callback = ModelCheckpoint(
# every_n_epochs = EVERY_N_EPOCHES,
# save_top_k = SAVE_TOP_K,
# monitor = 'val_acc',
# mode = 'max',
# save_last = True,
# filename = '{epoch}-{val_loss:.2f}-{val_acc:.2f}'
# )
#
# earlystop_callback = EarlyStopping(
# monitor = "val_acc",
# mode = "max",
# min_delta = 0.00,
# patience = EARLY_STOP,
# )
#
# # training
# trainer = pl.Trainer(
# log_every_n_steps=1,
# devices = GPUS,
# max_epochs = EPOCHS,
# logger = logger,
# callbacks = [earlystop_callback, checkpoint_callback],
# )
#
# trainer.fit(
# model,
# datamodule,
# )
#
# # inference test
# predictions = trainer.predict(
# model,
# datamodule.test_dataset,
# ckpt_path = 'best',
# )
# save_results(trainer.log_dir, predictions, 'test')
#
# # inference e2e
# # predictions = trainer.predict(
# # model,
# # datamodule.pred_dataset,
# # ckpt_path = 'best',
# # )
# #
# # save_results(trainer.log_dir, predictions, 'e2e')