import torch
from torch import nn


class MLP(nn.Module):
    """A simple 4-layer MLP with LeakyReLU activations."""

    def __init__(self, nin, nout, nh):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(nin, nh),
            nn.LeakyReLU(0.2),
            nn.Linear(nh, nh),
            nn.LeakyReLU(0.2),
            nn.Linear(nh, nh),
            nn.LeakyReLU(0.2),
            nn.Linear(nh, nout),
        )

    def forward(self, x):
        return self.net(x)

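
# Minimal usage sketch for MLP, assuming toy sizes; the `_demo_*` helpers in
# this file are illustrative only and are not called anywhere in the module.
def _demo_mlp():
    mlp = MLP(nin=3, nout=8, nh=64)
    x = torch.randn(5, 3)  # batch of 5 three-dimensional inputs
    return mlp(x)          # -> shape (5, 8)
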

class GCL_basic(nn.Module):
    """Base class for graph convolutional layers.

    Subclasses implement `edge_model` (messages from node pairs) and
    `node_model` (node updates from aggregated messages); `forward` wires
    the two together over an edge list.
    """

    def __init__(self):
        super(GCL_basic, self).__init__()

    def edge_model(self, source, target, edge_attr):
        raise NotImplementedError

    def node_model(self, h, edge_index, edge_attr):
        raise NotImplementedError

    def forward(self, x, edge_index, edge_attr=None):
        row, col = edge_index
        edge_feat = self.edge_model(x[row], x[col], edge_attr)
        x = self.node_model(x, edge_index, edge_feat)
        return x, edge_feat


class GCL(GCL_basic):
    """Graph convolutional layer with MLP edge and node models.

    Args:
        input_nf: Number of input node features.
        output_nf: Number of output node features.
        hidden_nf: Number of hidden units.
        edges_in_nf: Number of input edge features.
        act_fn: Activation function.
        bias: Whether the linear layers use a bias.
        attention: If True, gate edge messages with a sigmoid attention weight.
        t_eq: Stored but unused.
        recurrent: If True, add a residual connection to the node update.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_nf=0, act_fn=nn.ReLU(), bias=True, attention=False, t_eq=False, recurrent=True):
        super(GCL, self).__init__()
        self.attention = attention
        self.t_eq = t_eq
        self.recurrent = recurrent
        input_edge_nf = input_nf * 2
        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge_nf + edges_in_nf, hidden_nf, bias=bias),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf, bias=bias),
            act_fn)
        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(input_nf, hidden_nf, bias=bias),
                act_fn,
                nn.Linear(hidden_nf, 1, bias=bias),
                nn.Sigmoid())

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf, hidden_nf, bias=bias),
            act_fn,
            nn.Linear(hidden_nf, output_nf, bias=bias))

        #if recurrent:
        #    self.gru = nn.GRUCell(hidden_nf, hidden_nf)

    def edge_model(self, source, target, edge_attr):
        edge_in = torch.cat([source, target], dim=1)
        if edge_attr is not None:
            edge_in = torch.cat([edge_in, edge_attr], dim=1)
        out = self.edge_mlp(edge_in)
        if self.attention:
            att = self.att_mlp(torch.abs(source - target))
            out = out * att
        return out

    def node_model(self, h, edge_index, edge_attr):
        row, col = edge_index
        # Sum incoming edge messages per receiving node.
        agg = unsorted_segment_sum(edge_attr, row, num_segments=h.size(0))
        out = torch.cat([h, agg], dim=1)
        out = self.node_mlp(out)
        if self.recurrent:
            # Residual connection; requires output_nf == input_nf.
            out = out + h
            #out = self.gru(out, h)
        return out

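
# Usage sketch for GCL, assuming a small fully connected directed graph
# without self-loops and input_nf == output_nf so the residual applies.
def _demo_gcl():
    n, nf = 4, 8
    row = torch.LongTensor([i for i in range(n) for j in range(n) if i != j])
    col = torch.LongTensor([j for i in range(n) for j in range(n) if i != j])
    h = torch.randn(n, nf)                 # node features
    gcl = GCL(input_nf=nf, output_nf=nf, hidden_nf=16)
    h_out, edge_feat = gcl(h, [row, col])  # -> (4, 8) nodes, (12, 16) messages
    return h_out, edge_feat
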

class GCL_rf(GCL_basic):
    """Radial-field graph layer acting directly on coordinates.

    Edge messages are coordinate differences scaled by a learned scalar
    function of the inter-node distance and optional edge attributes.

    Args:
        nf: Number of hidden units.
        edge_attr_nf: Number of input edge features.
        reg: Weight of the -reg * x shrinkage term in the node update.
        act_fn: Activation function.
        clamp: If True, clamp edge messages to [-100, 100].
    """

    def __init__(self, nf=64, edge_attr_nf=0, reg=0, act_fn=nn.LeakyReLU(0.2), clamp=False):
        super(GCL_rf, self).__init__()

        self.clamp = clamp
        layer = nn.Linear(nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
        self.phi = nn.Sequential(nn.Linear(edge_attr_nf + 1, nf),
                                 act_fn,
                                 layer)
        self.reg = reg

    def edge_model(self, source, target, edge_attr):
        x_diff = source - target
        radial = torch.sqrt(torch.sum(x_diff ** 2, dim=1)).unsqueeze(1)
        e_input = torch.cat([radial, edge_attr], dim=1)
        e_out = self.phi(e_input)
        m_ij = x_diff * e_out
        if self.clamp:
            m_ij = torch.clamp(m_ij, min=-100, max=100)
        return m_ij

    def node_model(self, x, edge_index, edge_attr):
        row, col = edge_index
        agg = unsorted_segment_mean(edge_attr, row, num_segments=x.size(0))
        x_out = x + agg - x * self.reg
        return x_out

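
# Usage sketch for GCL_rf, assuming 3-d positions and one scalar attribute
# per edge (edge_attr_nf=1); the layer returns updated positions.
def _demo_gcl_rf():
    n = 4
    row = torch.LongTensor([i for i in range(n) for j in range(n) if i != j])
    col = torch.LongTensor([j for i in range(n) for j in range(n) if i != j])
    x = torch.randn(n, 3)                    # node positions
    edge_attr = torch.randn(row.size(0), 1)  # per-edge scalar
    layer = GCL_rf(nf=32, edge_attr_nf=1)
    x_out, _ = layer(x, [row, col], edge_attr)
    return x_out                             # -> shape (4, 3)
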

class E_GCL(nn.Module):
    """Equivariant graph convolutional layer (EGNN layer).

    Jointly updates invariant node features h and equivariant coordinates:
    edge messages are built from node-pair features and squared distances,
    coordinates are moved along relative-position vectors scaled by a
    learned scalar, and node features are updated from aggregated messages.

    Args:
        input_nf: Number of input node features.
        output_nf: Number of output node features.
        hidden_nf: Number of hidden units.
        edges_in_d: Number of input edge features.
        nodes_att_dim: Number of extra per-node attributes fed to the node MLP.
        act_fn: Activation function.
        recurrent: If True, add a residual connection to the node update.
        coords_weight: Scale applied to the coordinate update.
        attention: If True, gate edge messages with a sigmoid attention weight.
        clamp: Stored but unused; the coordinate update is always clamped.
        norm_diff: If True, normalize coordinate differences.
        tanh: If True, squash the coordinate scalar with Tanh.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU(), recurrent=True, coords_weight=1.0, attention=False, clamp=False, norm_diff=False, tanh=False):
        super(E_GCL, self).__init__()
        input_edge = input_nf * 2
        self.coords_weight = coords_weight
        self.recurrent = recurrent
        self.attention = attention
        self.norm_diff = norm_diff
        self.tanh = tanh
        edge_coords_nf = 1  # the squared distance is a single scalar per edge

        self.edge_mlp = nn.Sequential(
            nn.Linear(input_edge + edge_coords_nf + edges_in_d, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, hidden_nf),
            act_fn)

        self.node_mlp = nn.Sequential(
            nn.Linear(hidden_nf + input_nf + nodes_att_dim, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, output_nf))

        layer = nn.Linear(hidden_nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)

        self.clamp = clamp
        coord_mlp = []
        coord_mlp.append(nn.Linear(hidden_nf, hidden_nf))
        coord_mlp.append(act_fn)
        coord_mlp.append(layer)
        if self.tanh:
            coord_mlp.append(nn.Tanh())
            # The scale must be created inside nn.Parameter; multiplying the
            # Parameter afterwards would yield a plain, unregistered tensor.
            self.coords_range = nn.Parameter(torch.ones(1) * 3)
        self.coord_mlp = nn.Sequential(*coord_mlp)

        if self.attention:
            self.att_mlp = nn.Sequential(
                nn.Linear(hidden_nf, 1),
                nn.Sigmoid())

        #if recurrent:
        #    self.gru = nn.GRUCell(hidden_nf, hidden_nf)

    def edge_model(self, source, target, radial, edge_attr):
        if edge_attr is None:  # Unused.
            out = torch.cat([source, target, radial], dim=1)
        else:
            out = torch.cat([source, target, radial, edge_attr], dim=1)
        out = self.edge_mlp(out)
        if self.attention:
            att_val = self.att_mlp(out)
            out = out * att_val
        return out

    def node_model(self, x, edge_index, edge_attr, node_attr):
        row, col = edge_index
        agg = unsorted_segment_sum(edge_attr, row, num_segments=x.size(0))
        if node_attr is not None:
            agg = torch.cat([x, agg, node_attr], dim=1)
        else:
            agg = torch.cat([x, agg], dim=1)
        out = self.node_mlp(agg)
        if self.recurrent:
            out = x + out
        return out, agg

    def coord_model(self, coord, edge_index, coord_diff, edge_feat):
        row, col = edge_index
        trans = coord_diff * self.coord_mlp(edge_feat)
        # Rarely active, but guards against exploding coordinate updates
        # that could otherwise destabilize training.
        trans = torch.clamp(trans, min=-100, max=100)
        agg = unsorted_segment_mean(trans, row, num_segments=coord.size(0))
        coord += agg * self.coords_weight
        return coord

    def coord2radial(self, edge_index, coord):
        row, col = edge_index
        coord_diff = coord[row] - coord[col]
        radial = torch.sum(coord_diff ** 2, 1).unsqueeze(1)

        if self.norm_diff:
            norm = torch.sqrt(radial) + 1
            coord_diff = coord_diff / norm

        return radial, coord_diff

    def forward(self, h, edge_index, coord, edge_attr=None, node_attr=None):
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        # coord = self.node_coord_model(h, coord)
        # x = self.node_model(x, edge_index, x[col], u, batch)  # GCN
        # Note: returns the input edge_attr, not the computed edge_feat.
        return h, coord, edge_attr

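
# Usage sketch for E_GCL on an assumed toy graph: one layer jointly updates
# node features h and 3-d coordinates.
def _demo_e_gcl():
    n, nf = 5, 8
    row = torch.LongTensor([i for i in range(n) for j in range(n) if i != j])
    col = torch.LongTensor([j for i in range(n) for j in range(n) if i != j])
    h = torch.randn(n, nf)
    coord = torch.randn(n, 3)
    layer = E_GCL(input_nf=nf, output_nf=nf, hidden_nf=16)
    h_out, coord_out, _ = layer(h, [row, col], coord)
    return h_out, coord_out  # -> (5, 8), (5, 3)
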

class E_GCL_vel(E_GCL):
    """E_GCL with a velocity channel.

    Identical to E_GCL except that the coordinate update also adds the node
    velocity, scaled by a learned scalar function of the node features.

    Args: same as E_GCL, minus `clamp`.
    """

    def __init__(self, input_nf, output_nf, hidden_nf, edges_in_d=0, nodes_att_dim=0, act_fn=nn.ReLU(), recurrent=True, coords_weight=1.0, attention=False, norm_diff=False, tanh=False):
        E_GCL.__init__(self, input_nf, output_nf, hidden_nf, edges_in_d=edges_in_d, nodes_att_dim=nodes_att_dim, act_fn=act_fn, recurrent=recurrent, coords_weight=coords_weight, attention=attention, norm_diff=norm_diff, tanh=tanh)
        self.norm_diff = norm_diff
        self.coord_mlp_vel = nn.Sequential(
            nn.Linear(input_nf, hidden_nf),
            act_fn,
            nn.Linear(hidden_nf, 1))

    def forward(self, h, edge_index, coord, vel, edge_attr=None, node_attr=None):
        row, col = edge_index
        radial, coord_diff = self.coord2radial(edge_index, coord)

        edge_feat = self.edge_model(h[row], h[col], radial, edge_attr)
        coord = self.coord_model(coord, edge_index, coord_diff, edge_feat)

        # Add the velocity contribution, gated by a learned per-node scalar.
        coord += self.coord_mlp_vel(h) * vel
        h, agg = self.node_model(h, edge_index, edge_feat, node_attr)
        return h, coord, edge_attr

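
# Usage sketch for E_GCL_vel: same as E_GCL but with a per-node velocity
# folded into the coordinate update (toy shapes assumed).
def _demo_e_gcl_vel():
    n, nf = 5, 8
    row = torch.LongTensor([i for i in range(n) for j in range(n) if i != j])
    col = torch.LongTensor([j for i in range(n) for j in range(n) if i != j])
    h = torch.randn(n, nf)
    coord, vel = torch.randn(n, 3), torch.randn(n, 3)
    layer = E_GCL_vel(input_nf=nf, output_nf=nf, hidden_nf=16)
    h_out, coord_out, _ = layer(h, [row, col], coord, vel)
    return h_out, coord_out
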

class GCL_rf_vel(nn.Module):
    """Radial-field layer with a velocity channel.

    Args:
        nf: Number of hidden units.
        edge_attr_nf: Number of input edge features.
        act_fn: Activation function.
        coords_weight: Scale applied to the coordinate update.
    """

    def __init__(self, nf=64, edge_attr_nf=0, act_fn=nn.LeakyReLU(0.2), coords_weight=1.0):
        super(GCL_rf_vel, self).__init__()
        self.coords_weight = coords_weight
        self.coord_mlp_vel = nn.Sequential(
            nn.Linear(1, nf),
            act_fn,
            nn.Linear(nf, 1))

        layer = nn.Linear(nf, 1, bias=False)
        torch.nn.init.xavier_uniform_(layer.weight, gain=0.001)
        #layer.weight.uniform_(-0.1, 0.1)
        self.phi = nn.Sequential(nn.Linear(1 + edge_attr_nf, nf),
                                 act_fn,
                                 layer,
                                 nn.Tanh())  # The Tanh keeps this method stable.

    def forward(self, x, vel_norm, vel, edge_index, edge_attr=None):
        row, col = edge_index
        edge_m = self.edge_model(x[row], x[col], edge_attr)
        x = self.node_model(x, edge_index, edge_m)
        x += vel * self.coord_mlp_vel(vel_norm)
        return x, edge_attr

    def edge_model(self, source, target, edge_attr):
        x_diff = source - target
        radial = torch.sqrt(torch.sum(x_diff ** 2, dim=1)).unsqueeze(1)
        e_input = torch.cat([radial, edge_attr], dim=1)
        e_out = self.phi(e_input)
        m_ij = x_diff * e_out
        return m_ij

    def node_model(self, x, edge_index, edge_m):
        row, col = edge_index
        agg = unsorted_segment_mean(edge_m, row, num_segments=x.size(0))
        x_out = x + agg * self.coords_weight
        return x_out

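
# Usage sketch for GCL_rf_vel: positions plus a velocity channel gated by the
# velocity norm (toy shapes assumed, one scalar attribute per edge).
def _demo_gcl_rf_vel():
    n = 4
    row = torch.LongTensor([i for i in range(n) for j in range(n) if i != j])
    col = torch.LongTensor([j for i in range(n) for j in range(n) if i != j])
    x, vel = torch.randn(n, 3), torch.randn(n, 3)
    vel_norm = vel.norm(dim=1, keepdim=True)  # (n, 1) speed per node
    edge_attr = torch.randn(row.size(0), 1)
    layer = GCL_rf_vel(nf=32, edge_attr_nf=1)
    x_out, _ = layer(x, vel_norm, vel, [row, col], edge_attr)
    return x_out
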

def unsorted_segment_sum(data, segment_ids, num_segments):
    """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_sum`."""
    result_shape = (num_segments, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result.scatter_add_(0, segment_ids, data)
    return result


def unsorted_segment_mean(data, segment_ids, num_segments):
    """Custom PyTorch op to replicate TensorFlow's `unsorted_segment_mean`."""
    result_shape = (num_segments, data.size(1))
    segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
    result = data.new_full(result_shape, 0)  # Init empty result tensor.
    count = data.new_full(result_shape, 0)
    result.scatter_add_(0, segment_ids, data)
    count.scatter_add_(0, segment_ids, torch.ones_like(data))
    return result / count.clamp(min=1)

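
# Worked example for the segment ops above, with assumed tiny inputs: rows of
# `data` are summed / averaged into the bucket named by `segment_ids`. These
# are the aggregations used to pool edge messages per receiving node.
def _demo_segment_ops():
    data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    segment_ids = torch.tensor([0, 0, 1])  # rows 0,1 -> segment 0; row 2 -> segment 1
    s = unsorted_segment_sum(data, segment_ids, num_segments=2)
    m = unsorted_segment_mean(data, segment_ids, num_segments=2)
    # s == [[4., 6.], [5., 6.]];  m == [[2., 3.], [5., 6.]]
    return s, m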