# material-demo/models/potnet.py
#
# (source-viewer metadata: 159 lines, 5.5 KiB, Python)

import math
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pydantic.typing import Literal
from torch_geometric.nn import Linear, MessagePassing, global_mean_pool
from torch_geometric.nn.models.schnet import ShiftedSoftplus
from models.base import BaseSettings
from models.transformer import TransformerConv
from models.utils import RBFExpansion
class PotNetConfig(BaseSettings):
    """Hyper-parameter container for the ``PotNet`` model.

    Every field has a default except ``name``, which must be ``"potnet"``.
    The annotations spell out the same types pydantic would otherwise infer
    from the default values.
    """

    # Discriminator for the model type; only "potnet" is accepted.
    name: Literal["potnet"]
    # Number of PotNetConv layers (and TransformerConv layers when enabled).
    conv_layers: int = 3
    # Width of the raw per-node input features (data.x).
    atom_input_features: int = 92
    # Number of RBF bins used to expand the infinite-edge feature.
    inf_edge_features: int = 64
    # Hidden width of node/edge embeddings; also the RBF bin count for local edges.
    fc_features: int = 256
    # Input width of the final output linear layer.
    output_dim: int = 256
    # Number of predicted values per graph.
    output_features: int = 1
    # Lower / upper centre bounds passed to RBFExpansion as vmin / vmax.
    rbf_min: float = -4.0
    rbf_max: float = 4.0
    # One weight per column of data.inf_edge_attr; columns are combined as a
    # weighted sum. Must be non-empty unless `euclidean` is True.
    potentials: list = []
    # True: use only local edges (no infinite-edge branch).
    euclidean: bool = False
    # True: append data.g_feats (10 extra columns) to the node input features.
    charge_map: bool = False
    # True (with euclidean False): process infinite edges with separate
    # TransformerConv layers instead of merging them into the local graph.
    transformer: bool = False

    class Config:
        """Configure model settings behavior."""

        env_prefix = "jv_model"
class PotNetConv(MessagePassing):
    """Gated message-passing layer with a residual connection.

    For every edge the message is ``gate * value``. Both ``gate`` and
    ``value`` are two-layer MLPs applied to the concatenation
    ``[x_i, x_j, edge_attr]``. The gate additionally goes through batch norm
    and a sigmoid. Messages are combined with ``MessagePassing``'s default
    aggregation (no ``aggr`` is passed here), then batch-normalised, added
    to the input node features and passed through ReLU.

    Args:
        fc_features: channel width of node features and of edge features.
            The MLPs take ``3 * fc_features`` inputs, and the residual add
            requires the node features to have this width.
    """

    def __init__(self, fc_features):
        super(PotNetConv, self).__init__(node_dim=0)
        self.bn = nn.BatchNorm1d(fc_features)
        self.bn_interaction = nn.BatchNorm1d(fc_features)
        # Gate branch: produces pre-sigmoid attention scores per message.
        self.nonlinear_full = nn.Sequential(
            nn.Linear(3 * fc_features, fc_features),
            nn.SiLU(),
            nn.Linear(fc_features, fc_features),
        )
        # Value branch: produces the message content that gets gated.
        self.nonlinear = nn.Sequential(
            nn.Linear(3 * fc_features, fc_features),
            nn.SiLU(),
            nn.Linear(fc_features, fc_features),
        )

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing and return updated node features."""
        num_nodes = x.size(0)
        out = self.propagate(
            edge_index, x=x, edge_attr=edge_attr, size=(num_nodes, num_nodes)
        )
        return F.relu(x + self.bn(out))

    def message(self, x_i, x_j, edge_attr, index):
        """Compute the gated message for each edge.

        ``index`` is accepted to match the original signature but is not used.
        """
        # Build the per-edge input once; both branches consume the same tensor.
        z = torch.cat((x_i, x_j, edge_attr), dim=1)
        gate = torch.sigmoid(self.bn_interaction(self.nonlinear_full(z)))
        return gate * self.nonlinear(z)
class PotNet(nn.Module):
    """PotNet graph network mapping a batched crystal graph to predictions.

    Components:

    - Node features are embedded with a linear layer.
    - Local edges are embedded through an RBF expansion followed by
      Linear + SiLU.
    - Unless ``config.euclidean`` is set, a second set of "infinite" edges is
      embedded. Its features are a ``config.potentials``-weighted sum of the
      ``inf_edge_attr`` columns, passed through a multiquadric RBF expansion,
      a linear layer, softplus and batch norm. These edges are either merged
      into the local graph or, with ``config.transformer``, handled by
      separate ``TransformerConv`` layers.
    - Node features are mean-pooled per graph and passed through two linear
      heads.
    """

    def __init__(self, config: Optional[PotNetConfig] = None):
        """Build all sub-modules from ``config``.

        Args:
            config: model hyper-parameters. ``None`` (the default) creates a
                fresh ``PotNetConfig(name="potnet")`` for this instance, rather
                than sharing one config object built at import time.
        """
        super().__init__()
        if config is None:
            config = PotNetConfig(name="potnet")
        self.config = config

        # With charge_map, data.g_feats (10 extra columns) is appended to data.x.
        atom_in = config.atom_input_features + (10 if config.charge_map else 0)
        self.atom_embedding = nn.Linear(atom_in, config.fc_features)

        self.edge_embedding = nn.Sequential(
            RBFExpansion(
                vmin=config.rbf_min,
                vmax=config.rbf_max,
                bins=config.fc_features,
            ),
            nn.Linear(config.fc_features, config.fc_features),
            nn.SiLU(),
        )

        if not config.euclidean:
            self.inf_edge_embedding = RBFExpansion(
                vmin=config.rbf_min,
                vmax=config.rbf_max,
                bins=config.inf_edge_features,
                type='multiquadric',
            )
            self.infinite_linear = nn.Linear(config.inf_edge_features, config.fc_features)
            self.infinite_bn = nn.BatchNorm1d(config.fc_features)

        self.conv_layers = nn.ModuleList(
            [PotNetConv(config.fc_features) for _ in range(config.conv_layers)]
        )

        if not config.euclidean and config.transformer:
            self.transformer_conv_layers = nn.ModuleList(
                [
                    TransformerConv(config.fc_features, config.fc_features)
                    for _ in range(config.conv_layers)
                ]
            )

        # fc projects pooled features to output_dim so fc_out's input width
        # matches even when output_dim differs from fc_features.
        self.fc = nn.Sequential(
            nn.Linear(config.fc_features, config.output_dim), ShiftedSoftplus()
        )
        self.fc_out = nn.Linear(config.output_dim, config.output_features)

    def forward(self, data, print_data=False):
        """Map a batched graph to per-graph predictions.

        Args:
            data: batched graph. Always reads ``x``, ``edge_index``,
                ``edge_attr`` and ``batch``. Also reads ``inf_edge_index`` and
                ``inf_edge_attr`` when ``config.euclidean`` is False, and
                ``g_feats`` when ``config.charge_map`` is True.
            print_data: unused; kept for interface compatibility.

        Returns:
            ``torch.squeeze`` of the ``fc_out`` output. Every size-1 dimension
            is removed, so one graph with a single output gives a 0-d tensor.

        Raises:
            ValueError: if the infinite-edge branch is active but
                ``config.potentials`` is empty.
        """
        config = self.config
        use_inf = not config.euclidean

        # Local edge features. In non-euclidean mode the raw attribute is
        # mapped through -0.75 / edge_attr before RBF expansion
        # (NOTE(review): presumably an inverse-distance scaling; confirm).
        if config.euclidean:
            edge_features = self.edge_embedding(data.edge_attr)
        else:
            edge_features = self.edge_embedding(-0.75 / data.edge_attr)

        if use_inf:
            if not config.potentials:
                # sum() over no terms would yield the int 0 and fail deep
                # inside the RBF / BatchNorm layers; fail clearly instead.
                raise ValueError(
                    "config.potentials must be non-empty when config.euclidean "
                    f"is False (got {config.potentials!r})"
                )
            inf_edge_index = data.inf_edge_index
            # Weighted sum of the inf_edge_attr columns, one weight per potential.
            inf_feat = sum(
                data.inf_edge_attr[:, i] * pot
                for i, pot in enumerate(config.potentials)
            )
            inf_edge_features = self.inf_edge_embedding(inf_feat)
            inf_edge_features = self.infinite_bn(
                F.softplus(self.infinite_linear(inf_edge_features))
            )

        # Initial node features.
        if config.charge_map:
            node_features = self.atom_embedding(torch.cat([data.x, data.g_feats], -1))
        else:
            node_features = self.atom_embedding(data.x)

        edge_index = data.edge_index
        if use_inf and config.transformer:
            # Local edges via PotNetConv, infinite edges via TransformerConv;
            # the two updates are summed at every layer.
            for conv, inf_conv in zip(self.conv_layers, self.transformer_conv_layers):
                local_node_features = conv(node_features, edge_index, edge_features)
                inf_node_features = inf_conv(node_features, inf_edge_index, inf_edge_features)
                node_features = local_node_features + inf_node_features
        else:
            if use_inf:
                # Merge infinite edges into the local graph for PotNetConv.
                edge_index = torch.cat([edge_index, inf_edge_index], 1)
                edge_features = torch.cat([edge_features, inf_edge_features], 0)
            for conv in self.conv_layers:
                node_features = conv(node_features, edge_index, edge_features)

        features = global_mean_pool(node_features, data.batch)
        features = self.fc(features)
        return torch.squeeze(self.fc_out(features))