atom-predict/egnn_v2/egnn_core/model.py

23 lines
664 B
Python

import torch.nn as nn
from egnn_core.egnn_clean import EGNN
from model_type_dict import cz_label
class PL_EGNN(nn.Module):
    """Thin wrapper around ``EGNN`` that returns only the model's first output.

    The wrapped network is always built with 9 input node features, a hidden
    width of 16, 2 layers and attention enabled. Only the number of output
    node features depends on ``model_type``.
    """

    def __init__(self, model_type=None) -> None:
        """Build the wrapped ``EGNN``.

        Args:
            model_type: Selects the output width. A value equal to
                ``cz_label`` gives 2 output node features; any other value,
                including the default ``None``, gives 3.
        """
        # Initialise nn.Module first so submodule registration is ready
        # before ``self.model`` is assigned.
        super().__init__()

        out_node_nf = 2 if model_type == cz_label else 3

        self.model = EGNN(
            in_node_nf=9,
            hidden_nf=16,
            out_node_nf=out_node_nf,
            n_layers=2,
            attention=True,
        )

    def forward(self, h, x, edges, return_features: bool = False):
        """Run the wrapped ``EGNN`` and return the first element of its result.

        Args:
            h: Forwarded unchanged as the model's first positional argument
                (presumably per-node features -- confirm against egnn_clean).
            x: Forwarded unchanged as the model's second positional argument
                (presumably node coordinates -- confirm against egnn_clean).
            edges: Forwarded unchanged as the model's third positional
                argument.
            return_features: Forwarded to the model as a keyword argument.
                Its effect on the output is defined by ``EGNN`` and is not
                visible from this wrapper.

        Returns:
            The first element of the pair returned by the wrapped model.
        """
        # The model returns a pair; the second element is discarded here.
        h, x = self.model(h, x, edges, return_features=return_features)
        return h