# material-demo/train_prop.py
import csv
import json
import os
import pprint
import random
import time
from typing import Any, Dict, Optional, Union

import ignite
import matplotlib.pyplot as plt
import numpy as np
import torch
from ignite.contrib.handlers import TensorboardLogger
from ignite.contrib.handlers.tensorboard_logger import global_step_from_engine
from ignite.contrib.handlers.tqdm_logger import ProgressBar
from ignite.engine import (
    Events,
    create_supervised_evaluator,
    create_supervised_trainer,
)
from ignite.handlers import Checkpoint, DiskSaver, EarlyStopping, TerminateOnNan
from ignite.handlers.stores import EpochOutputStore
from ignite.metrics import Loss, MeanAbsoluteError
from jarvis.core.atoms import Atoms
from jarvis.db.jsonutils import dumpjson, loadjson
from torch import nn
from tqdm import tqdm

from data import get_train_val_loaders
from models.config import TrainingConfig
from models.potnet import PotNet
plt.switch_backend("agg")

# fall back to CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def prepare_batch(batch, device=None, non_blocking=False):
    """Send a batched crystal graph and its targets to device."""
    batch = (
        batch.to(device, non_blocking=non_blocking),
        batch.y.to(device, non_blocking=non_blocking),
    )
    return batch
def group_decay(model):
    """Omit weight decay from bias and batchnorm params."""
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if "bias" in name or "bn" in name or "norm" in name:
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay},
        {"params": no_decay, "weight_decay": 0},
    ]
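# Illustrative sketch of the structure group_decay returns (placeholders, not
# real tensors); the second group overrides the optimizer-level weight decay:
#
#   [
#       {"params": [<weight tensors>]},                        # default decay
#       {"params": [<bias and norm tensors>], "weight_decay": 0},
#   ]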
def count_parameters(model):
    """Report the model's memory footprint (parameters + buffers) in MiB."""
    total_size = 0
    for parameter in model.parameters():
        total_size += parameter.element_size() * parameter.nelement()
    for buffer in model.buffers():
        total_size += buffer.element_size() * buffer.nelement()
    total_size = total_size / 1024 / 1024
    print(f"Model size (params + buffers): {total_size:.2f} MiB")
    return total_size
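# Worked example of the size arithmetic above: 1,000,000 float32 elements
# occupy 4 * 1_000_000 bytes, and 4_000_000 / 1024 / 1024 ≈ 3.81 MiB.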
def setup_optimizer(params, config: TrainingConfig):
    """Set up optimizer for param groups."""
    if config.optimizer == "adamw":
        optimizer = torch.optim.AdamW(
            params,
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )
    elif config.optimizer == "sgd":
        optimizer = torch.optim.SGD(
            params,
            lr=config.learning_rate,
            momentum=0.9,
            weight_decay=config.weight_decay,
        )
    else:
        raise ValueError(f"Unsupported optimizer: {config.optimizer}")
    return optimizer
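# A minimal sketch (not part of the original pipeline) of how group_decay and
# setup_optimizer compose; _demo_optimizer_setup, demo_net, and demo_cfg are
# hypothetical names. Note that group_decay matches parameter *names*, so norm
# layers must be registered under names containing "bn" or "norm".
def _demo_optimizer_setup():
    from collections import OrderedDict
    from types import SimpleNamespace

    demo_net = nn.Sequential(
        OrderedDict(
            [
                ("fc1", nn.Linear(4, 8)),
                ("bn1", nn.BatchNorm1d(8)),  # "bn1.*" params -> no weight decay
                ("fc2", nn.Linear(8, 1)),
            ]
        )
    )
    # Stand-in exposing only the fields setup_optimizer reads; the real code
    # passes a models.config.TrainingConfig.
    demo_cfg = SimpleNamespace(
        optimizer="adamw", learning_rate=1e-3, weight_decay=1e-5
    )
    optimizer = setup_optimizer(group_decay(demo_net), demo_cfg)
    # Two param groups: decayed weights, then bias/norm params with no decay.
    assert optimizer.param_groups[1]["weight_decay"] == 0
    return optimizer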
def train_pyg(
    config: Union[TrainingConfig, Dict[str, Any]],
    data_root: Optional[str] = None,
    file_format: str = "poscar",
    checkpoint: Optional[str] = None,
    testing: bool = False,
    train_val_test_loaders=None,
):
    """Train a PotNet model; optionally build the dataset from files under data_root."""
    print(config)
    if not isinstance(config, TrainingConfig):
        config = TrainingConfig(**config)
    if not os.path.exists(config.output_dir):
        os.makedirs(config.output_dir)
    checkpoint_dir = os.path.join(config.output_dir, config.checkpoint_dir)
    deterministic = False

    print("config:")
    tmp = config.dict()
    with open(os.path.join(config.output_dir, "config.json"), "w") as f:
        f.write(json.dumps(tmp, indent=4))
    pprint.pprint(tmp)

    if config.random_seed is not None:
        deterministic = True
        ignite.utils.manual_seed(config.random_seed)
        np.random.seed(config.random_seed)
        torch.manual_seed(config.random_seed)
        torch.cuda.manual_seed(config.random_seed)
        torch.cuda.manual_seed_all(config.random_seed)
        random.seed(config.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    if data_root:
        dataset_info = loadjson(os.path.join(data_root, "dataset_info.json"))
        if "n_train" in dataset_info:
            config.n_train = dataset_info["n_train"]
        if "n_val" in dataset_info:
            config.n_val = dataset_info["n_val"]
        if "n_test" in dataset_info:
            config.n_test = dataset_info["n_test"]
        if "train_ratio" in dataset_info:
            config.train_ratio = dataset_info["train_ratio"]
        if "val_ratio" in dataset_info:
            config.val_ratio = dataset_info["val_ratio"]
        if "test_ratio" in dataset_info:
            config.test_ratio = dataset_info["test_ratio"]
        config.keep_data_order = True
        config.target = "target"

        # Build the dataset from id_prop.csv: each row is a structure file
        # name followed by one or more target values.
        id_prop_dat = os.path.join(data_root, "id_prop.csv")
        with open(id_prop_dat, "r") as f:
            reader = csv.reader(f)
            data = [row for row in reader]

        dataset_array = []
        for i in data:
            info = {}
            file_name = i[0]
            file_path = os.path.join(data_root, file_name)
            if file_format == "poscar":
                atoms = Atoms.from_poscar(file_path)
            elif file_format == "cif":
                atoms = Atoms.from_cif(file_path)
            elif file_format == "xyz":
                # Note: using a 500 angstrom box size
                atoms = Atoms.from_xyz(file_path, box_size=500)
            elif file_format == "pdb":
                # Note: using a 500 angstrom box size.
                # Recommended: install pytraj
                #   conda install -c ambermd pytraj
                atoms = Atoms.from_pdb(file_path, max_lat=500)
            else:
                raise NotImplementedError(
                    "File format not implemented", file_format
                )
            info["atoms"] = atoms.to_dict()
            info["jid"] = file_name
            tmp = [float(j) for j in i[1:]]
            if len(tmp) == 1:
                tmp = tmp[0]
            info["target"] = tmp
            dataset_array.append(info)
    else:
        dataset_array = None
    print("output_dir train", config.output_dir)
    if not train_val_test_loaders:
        # use input standardization for all real-valued feature sets
        (
            train_loader,
            val_loader,
            test_loader,
            mean,
            std,
        ) = get_train_val_loaders(
            dataset=config.dataset,
            root=config.output_dir,
            cachedir=config.cache_dir,
            processdir=config.process_dir,
            dataset_array=dataset_array,
            target=config.target,
            n_train=config.n_train,
            n_val=config.n_val,
            n_test=config.n_test,
            train_ratio=config.train_ratio,
            val_ratio=config.val_ratio,
            test_ratio=config.test_ratio,
            batch_size=config.batch_size,
            atom_features=config.atom_features,
            id_tag=config.id_tag,
            pin_memory=config.pin_memory,
            workers=config.num_workers,
            normalize=config.normalize,
            euclidean=config.euclidean,
            cutoff=config.cutoff,
            max_neighbors=config.max_neighbors,
            infinite_funcs=config.infinite_funcs,
            infinite_params=config.infinite_params,
            R=config.R,
            keep_data_order=config.keep_data_order,
        )
    else:
        train_loader, val_loader, test_loader = train_val_test_loaders[:3]
        mean = 0.0
        std = 1.0
    # define network, optimizer, scheduler
    _model = {
        "potnet": PotNet,
    }
    config.model.euclidean = config.euclidean
    net_cls = _model.get(config.model.name)
    if net_cls is None:
        raise ValueError(f"Unknown model: {config.model.name}")
    net = net_cls(config.model)
    if checkpoint is not None:
        net.load_state_dict(torch.load(checkpoint)["model"])
    count_parameters(net)
    net.to(device)

    # group parameters to skip weight decay for bias and batchnorm
    params = group_decay(net)
    optimizer = setup_optimizer(params, config)

    if config.scheduler == "none":
        # always return a multiplier of 1 (i.e. do nothing)
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lambda epoch: 1.0
        )
    elif config.scheduler == "onecycle":
        steps_per_epoch = len(train_loader)
        pct_start = config.warmup_steps / (config.epochs * steps_per_epoch)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=config.learning_rate,
            epochs=config.epochs,
            steps_per_epoch=steps_per_epoch,
            pct_start=pct_start,
        )
    elif config.scheduler == "step":
        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=100, gamma=0.5
        )
    else:
        raise ValueError(f"Unsupported scheduler: {config.scheduler}")

    # select configured loss function
    criteria = {
        "mse": nn.MSELoss(),
        "l1": nn.L1Loss(),
        "poisson": nn.PoissonNLLLoss(log_input=False, full=True),
    }
    criterion = criteria[config.criterion]
    # set up training engine and evaluators; MAE is rescaled by the target
    # std so reported errors are in the original units
    metrics = {
        "loss": Loss(criterion),
        "mae": MeanAbsoluteError() * std,
        "neg_mae": -1.0 * MeanAbsoluteError() * std,
    }
    trainer = create_supervised_trainer(
        net,
        optimizer,
        criterion,
        prepare_batch=prepare_batch,
        device=device,
        deterministic=deterministic,
    )
    evaluator = create_supervised_evaluator(
        net,
        metrics=metrics,
        prepare_batch=prepare_batch,
        device=device,
    )
    train_evaluator = create_supervised_evaluator(
        net,
        metrics=metrics,
        prepare_batch=prepare_batch,
        device=device,
    )
    # ignite event handlers:
    trainer.add_event_handler(Events.EPOCH_COMPLETED, TerminateOnNan())

    # apply learning rate scheduler
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED, lambda engine: scheduler.step()
    )

    if config.write_checkpoint:
        # model checkpointing: keep the five best epochs by validation neg_mae
        to_save = {
            "model": net,
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            "trainer": trainer,
        }
        handler = Checkpoint(
            to_save,
            DiskSaver(checkpoint_dir, create_dir=True, require_empty=False),
            n_saved=5,
            score_name="neg_mae",
            global_step_transform=lambda *_: trainer.state.epoch,
        )
        evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler)

    if config.progress:
        pbar = ProgressBar()
        pbar.attach(trainer, output_transform=lambda x: {"loss": x})

    history = {
        "train": {m: [] for m in metrics.keys()},
        "validation": {m: [] for m in metrics.keys()},
    }

    if config.store_outputs:
        eos = EpochOutputStore()
        eos.attach(evaluator)
        train_eos = EpochOutputStore()
        train_eos.attach(train_evaluator)
    # collect evaluation performance at the end of every epoch
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(engine):
        train_evaluator.run(train_loader)
        evaluator.run(val_loader)
        tmetrics = train_evaluator.state.metrics
        vmetrics = evaluator.state.metrics
        for metric in metrics.keys():
            tm = tmetrics[metric]
            vm = vmetrics[metric]
            if isinstance(tm, torch.Tensor):
                tm = tm.cpu().numpy().tolist()
            if isinstance(vm, torch.Tensor):
                vm = vm.cpu().numpy().tolist()
            history["train"][metric].append(tm)
            history["validation"][metric].append(vm)
        if config.store_outputs:
            history["EOS"] = eos.data
            history["trainEOS"] = train_eos.data
        dumpjson(
            filename=os.path.join(
                config.output_dir,
                config.model.name + "_" + config.target + "_history_val.json",
            ),
            data=history["validation"],
        )
        dumpjson(
            filename=os.path.join(
                config.output_dir,
                config.model.name + "_" + config.target + "_history_train.json",
            ),
            data=history["train"],
        )
        if config.progress:
            pbar = ProgressBar()
            pbar.log_message(f"Val_MAE: {vmetrics['mae']:.4f}")
            pbar.log_message(f"Train_MAE: {tmetrics['mae']:.4f}")
    if config.n_early_stopping is not None:
        # EarlyStopping stops when the score stops increasing, so the score
        # must be the negated MAE (lower MAE -> higher score)
        def default_score_fn(engine):
            return -engine.state.metrics["mae"]

        es_handler = EarlyStopping(
            patience=config.n_early_stopping,
            score_function=default_score_fn,
            trainer=trainer,
        )
        evaluator.add_event_handler(Events.EPOCH_COMPLETED, es_handler)

    if config.log_tensorboard:
        tb_logger = TensorboardLogger(
            log_dir=os.path.join(config.output_dir, "tb_logs", "test")
        )
        # use a distinct loop variable to avoid shadowing `evaluator`
        for tag, eval_engine in [
            ("training", train_evaluator),
            ("validation", evaluator),
        ]:
            tb_logger.attach_output_handler(
                eval_engine,
                event_name=Events.EPOCH_COMPLETED,
                tag=tag,
                metric_names=["loss", "mae", "neg_mae"],
                global_step_transform=global_step_from_engine(trainer),
            )
    # train the model!
    if not testing:
        trainer.run(train_loader, max_epochs=config.epochs)

    if config.log_tensorboard:
        # validation loss from the last completed epoch
        test_loss = evaluator.state.metrics["loss"]
        # add_hparams only accepts scalar hyperparameter values
        hparams = {
            k: v
            for k, v in config.dict().items()
            if isinstance(v, (int, float, str, bool))
        }
        tb_logger.writer.add_hparams(hparams, {"hparam/test_loss": test_loss})
        tb_logger.close()

    print("Testing!")
    net.eval()
    t1 = time.time()
    if config.write_predictions:
        f = open(
            os.path.join(
                config.output_dir,
                config.model.name
                + "_"
                + config.target
                + "_prediction_results_test_set.csv",
            ),
            "w",
        )
        f.write("id,target,prediction\n")
    targets = []
    predictions = []
    with torch.no_grad():
        # note: pairing loader batches with ids assumes a test batch size of 1
        ids = test_loader.dataset.ids
        for dat, sample_id in tqdm(zip(test_loader, ids)):
            data = dat
            target = dat.label
            out_data = net(data.to(device))
            out_data = out_data.cpu().numpy().tolist()
            target = target.cpu().numpy().flatten().tolist()
            if isinstance(out_data, list) and len(out_data) == 1:
                out_data = out_data[0]
            if len(target) == 1:
                target = target[0]
            # undo target standardization
            out_data = out_data * std + mean
            if config.write_predictions:
                f.write("%s,%.6f,%.6f\n" % (sample_id, target, out_data))
            targets.append(target)
            predictions.append(out_data)
    if config.write_predictions:
        f.close()
    t2 = time.time()
    print("Test time(s):", t2 - t1)

    from sklearn.metrics import mean_absolute_error

    test_mae = mean_absolute_error(np.array(targets), np.array(predictions))
    print("Test MAE:", test_mae)
    return test_mae
def train_prop_model(
    config: Dict,
    data_root: Optional[str] = None,
    checkpoint: Optional[str] = None,
    testing: bool = False,
    file_format: str = "poscar",
):
    """Apply dataset-specific defaults, then run train_pyg."""
    if config["dataset"] == "megnet":
        config["id_tag"] = "id"
        if config["target"] == "e_form" or config["target"] == "gap pbe":
            config["n_train"] = 60000
            config["n_val"] = 5000
            config["n_test"] = 4239
    result = train_pyg(
        config,
        data_root=data_root,
        file_format=file_format,
        checkpoint=checkpoint,
        testing=testing,
    )
    return result
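# A minimal usage sketch, assuming a data_root directory laid out as the
# id_prop.csv branch above expects (id_prop.csv plus POSCAR files). The config
# keys and values below are illustrative, and any field omitted here must have
# a default on models.config.TrainingConfig.
if __name__ == "__main__":
    demo_config = {
        "dataset": "user_data",  # hypothetical dataset name
        "target": "target",
        "epochs": 100,
        "batch_size": 64,
        "output_dir": "./out",
    }
    train_prop_model(
        demo_config,
        data_root="./my_dataset",  # hypothetical path
        file_format="poscar",
    )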