23 lines
664 B
Python
23 lines
664 B
Python
import torch.nn as nn
|
|
from egnn_core.egnn_clean import EGNN
|
|
from model_type_dict import cz_label
|
|
|
|
class PL_EGNN(nn.Module):
    """Thin ``nn.Module`` wrapper around ``EGNN`` with project defaults.

    The output width of the wrapped EGNN depends on ``model_type``: 2 when
    ``model_type == cz_label``, otherwise 3. The remaining EGNN
    hyperparameters default to the values this wrapper has always used and
    may be overridden by keyword.
    """

    def __init__(
        self,
        model_type=None,
        *,
        in_node_nf=9,
        hidden_nf=16,
        n_layers=2,
        attention=True,
    ):
        """Build the wrapped EGNN.

        Args:
            model_type: Selector compared against ``cz_label``. A match gives
                ``out_node_nf=2``; any other value (including the default
                ``None``) gives ``out_node_nf=3``.
            in_node_nf: Input node-feature width passed to ``EGNN``.
            hidden_nf: Hidden width passed to ``EGNN``.
            n_layers: Layer count passed to ``EGNN``.
            attention: Passed to ``EGNN``'s ``attention`` argument.
        """
        # Initialise nn.Module before registering any submodule on self.
        super().__init__()
        out_node_nf = 2 if model_type == cz_label else 3
        self.model = EGNN(
            in_node_nf=in_node_nf,
            hidden_nf=hidden_nf,
            out_node_nf=out_node_nf,
            n_layers=n_layers,
            attention=attention,
        )

    def forward(self, h, x, edges, return_features=False):
        """Run the wrapped EGNN and return only its first output.

        Args:
            h: Node features passed to EGNN; presumably shaped
                (n_nodes, in_node_nf) — TODO confirm against EGNN.
            x: Presumably node coordinates; passed through to EGNN unchanged.
            edges: Edge structure in whatever format EGNN expects.
            return_features: Forwarded unchanged to EGNN; its effect is
                defined by EGNN, not here.

        Returns:
            The first element of the EGNN output. The second element
            (presumably updated coordinates) is discarded.
        """
        # NOTE(review): assumes EGNN returns exactly two values even when
        # return_features=True — confirm in egnn_core.egnn_clean.
        h, _ = self.model(h, x, edges, return_features=return_features)
        return h