54 lines
1.1 KiB
Python
54 lines
1.1 KiB
Python
import time
|
|
import timm
|
|
import torch
|
|
import collections
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
from sklearn.manifold import TSNE
|
|
from core.data import AtomDataset
|
|
from torch_geometric.loader import DataLoader
|
|
|
|
import torch.nn as nn
|
|
from core.egnn_clean import EGNN
|
|
|
|
class PL_EGNN(nn.Module):
    """Thin ``nn.Module`` wrapper around ``EGNN`` that returns only node features.

    The constructor defaults reproduce the previously hard-coded configuration
    (in_node_nf=9, hidden_nf=9, out_node_nf=3, n_layers=2, attention=True), so
    ``PL_EGNN()`` builds exactly the same network as before. The keyword
    arguments only make it possible to benchmark other configurations without
    editing the class.
    """

    def __init__(
        self,
        in_node_nf=9,
        hidden_nf=9,
        out_node_nf=3,
        n_layers=2,
        attention=True,
    ):
        super().__init__()
        self.model = EGNN(
            in_node_nf=in_node_nf,
            hidden_nf=hidden_nf,
            out_node_nf=out_node_nf,
            n_layers=n_layers,
            attention=attention,
        )

    def forward(self, h, x, edges, return_features=False):
        """Run the wrapped EGNN and return its node-feature output.

        Args:
            h: node features; presumably (num_nodes, in_node_nf) -- TODO confirm.
            x: node coordinates; presumably (num_nodes, 3) -- TODO confirm.
            edges: edge index, passed straight through to ``EGNN``.
            return_features: forwarded unchanged to ``EGNN``; its meaning is
                defined by ``core.egnn_clean.EGNN`` (not visible here).

        Returns:
            The first value returned by ``EGNN`` (the node features ``h``).
            The second value (updated coordinates) is discarded.
        """
        h, _ = self.model(h, x, edges, return_features=return_features)
        return h
|
|
|
|
|
|
# --- Inference-latency benchmark ---------------------------------------------
# Loads the test split, takes its first batch, and times `nums` forward passes
# of an (untrained) PL_EGNN on that batch. Prints the dataset size, then the
# total elapsed wall-clock seconds for the timed passes.

ds = AtomDataset('../../data/gnn_data/test/')
dl = DataLoader(ds, batch_size=512)
print(len(ds))

# Fall back to CPU so the script still runs on machines without a GPU;
# behaviour is unchanged when CUDA is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PL_EGNN().to(device).eval()

# Only the first batch is benchmarked. Fail with a clear message on an empty
# dataset instead of a NameError on `light` further down.
data = next(iter(dl), None)
if data is None:
    raise RuntimeError("AtomDataset is empty; nothing to benchmark")
light = data.light.to(device)
pos = data.pos.to(device)
edge_index = data.edge_index.to(device)

nums = 88      # number of timed forward passes
warmup = 10    # untimed passes that absorb one-off initialisation costs

with torch.no_grad():
    # Warm-up so lazy initialisation (e.g. first-call CUDA setup, allocator
    # growth) is not counted in the measurement.
    for _ in range(warmup):
        model(light, pos, edge_index)
    if device.type == 'cuda':
        torch.cuda.synchronize()

    s = time.perf_counter()
    for i in range(nums):
        model(light, pos, edge_index)
    # CUDA kernels are launched asynchronously: without this barrier the
    # timer would stop before the GPU finished and only measure launch overhead.
    if device.type == 'cuda':
        torch.cuda.synchronize()
    e = time.perf_counter()

print(e - s)
|