atom-predict/egnn_v2/.ipynb_checkpoints/test_gnn-checkpoint.py

54 lines
1.1 KiB
Python

import time
import timm
import torch
import collections
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from core.data import AtomDataset
from torch_geometric.loader import DataLoader
import torch.nn as nn
from core.egnn_clean import EGNN
class PL_EGNN(nn.Module):
    """``nn.Module`` wrapper around ``EGNN`` that exposes only its first output.

    The wrapped network is constructed with fixed hyper-parameters:
    9 input node features, 9 hidden features, 3 output node features,
    2 layers, and attention enabled.
    """

    def __init__(self):
        super().__init__()
        egnn_kwargs = dict(
            in_node_nf=9,
            hidden_nf=9,
            out_node_nf=3,
            n_layers=2,
            attention=True,
        )
        self.model = EGNN(**egnn_kwargs)

    def forward(self, h, x, edges, return_features=False):
        """Run the wrapped EGNN and return the first of its two outputs.

        ``h``, ``x`` and ``edges`` are passed straight through to ``EGNN``
        along with ``return_features``. In this script they are
        ``data.light``, ``data.pos`` and ``data.edge_index`` respectively.
        EGNN must return exactly two values; the second one (presumably
        updated coordinates — TODO confirm against core.egnn_clean) is
        discarded.
        """
        node_out, _discarded = self.model(
            h, x, edges, return_features=return_features
        )
        return node_out
# Latency benchmark: time repeated forward passes of a freshly initialised
# (untrained) PL_EGNN on the first batch of the test split.
data_dir = '../../data/gnn_data/test/'
ds = AtomDataset(data_dir)
dl = DataLoader(ds, batch_size=512)
print(len(ds))

# Fall back to CPU so the script still runs on machines without a GPU;
# with CUDA available this is equivalent to the former `.cuda()` calls.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = PL_EGNN().to(device).eval()

# Only one batch is needed. An empty dataset used to surface as a confusing
# NameError on `light`; report it explicitly instead.
try:
    data = next(iter(dl))
except StopIteration:
    raise RuntimeError(f'no samples found in {data_dir}') from None
light = data.light.to(device)
pos = data.pos.to(device)
edge_index = data.edge_index.to(device)


def _synchronize(dev):
    """Block until all work queued on *dev* has finished (no-op off CUDA)."""
    if dev.type == 'cuda':
        torch.cuda.synchronize(dev)


nums = 88       # number of timed forward passes
n_warmup = 10   # untimed passes run before the measurement starts

with torch.no_grad():
    # Warm-up keeps one-off start-up costs (lazy library/kernel
    # initialisation, allocator growth) out of the timed loop.
    for _ in range(n_warmup):
        model(light, pos, edge_index)
    _synchronize(device)

    s = time.perf_counter()
    for _ in range(nums):
        model(light, pos, edge_index)
    # CUDA kernels launch asynchronously: without this sync the timer would
    # stop once the last kernel is queued, not when it has finished running.
    _synchronize(device)
    e = time.perf_counter()

print(e - s)
print(f'{nums} iterations on {device}: {(e - s) / nums * 1e3:.3f} ms/iter')