# Source code for egc.module.layers.gat_daegc

"""
GAT for DAEGC
"""
# pylint:disable=no-self-use
import torch
import torch.nn.functional as F
from torch import nn


class GAT(nn.Module):
    """GAT encoder with an inner-product decoder, as used by DAEGC.

    Two stacked ``GATLayer`` modules map node features to an embedding; the
    embedding is L2-normalised per node and decoded into a reconstructed
    adjacency matrix.

    Args:
        num_features (int): input feature dimension.
        hidden_size (int): number of units in the hidden layer.
        embedding_size (int): dimension of the output embedding.
        alpha (float): negative slope of the LeakyReLU used in each layer.
    """

    def __init__(self, num_features, hidden_size, embedding_size, alpha):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.alpha = alpha
        self.conv1 = GATLayer(num_features, hidden_size, alpha)
        self.conv2 = GATLayer(hidden_size, embedding_size, alpha)

    def forward(self, x, adj, M):
        """Encode the nodes and reconstruct the adjacency matrix.

        Args:
            x (torch.Tensor): features of nodes.
            adj (torch.Tensor): adjacency matrix.
            M (torch.Tensor): the topological relevance of node j to node i
                up to t orders.

        Returns:
            tuple: ``(A_pred, z)`` where ``A_pred`` (torch.Tensor) is the
            reconstructed adjacency matrix and ``z`` (torch.Tensor) is the
            row-wise L2-normalised latent representation.
        """
        hidden = self.conv1(x, adj, M)
        hidden = self.conv2(hidden, adj, M)
        # Unit-length rows keep every decoder logit z_i . z_j within [-1, 1].
        z = F.normalize(hidden, p=2, dim=1)
        A_pred = self.dot_product_decode(z)
        return A_pred, z

    def dot_product_decode(self, Z):
        """Decode embeddings into edge probabilities via inner products.

        Args:
            Z (torch.Tensor): node embedding, one row per node.

        Returns:
            torch.Tensor: reconstructed adjacency matrix,
            ``sigmoid(Z @ Z.T)``.
        """
        return torch.sigmoid(torch.matmul(Z, Z.t()))
class GATLayer(nn.Module):
    """Simple GAT layer, similar to https://arxiv.org/abs/1710.10903

    Args:
        in_features (int): dim num of input.
        out_features (int): dim num of output.
        alpha (float): negative slope of the LeakyReLU applied to the
            attention scores.
    """

    def __init__(self, in_features, out_features, alpha=0.2):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha

        # Linear projection applied to the node features.
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        # Attention vectors: one scores the node itself, one its neighbours.
        self.a_self = nn.Parameter(torch.zeros(size=(out_features, 1)))
        nn.init.xavier_uniform_(self.a_self.data, gain=1.414)

        self.a_neighs = nn.Parameter(torch.zeros(size=(out_features, 1)))
        nn.init.xavier_uniform_(self.a_neighs.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, x, adj, M, concat=True):
        """Forward Propagation

        Args:
            x (torch.Tensor): features of nodes.
            adj (torch.Tensor): adj matrix; entries ``> 0`` mark the pairs
                that take part in the attention softmax. Any dtype that
                supports ``> 0`` is accepted.
            M (torch.Tensor): the topological relevance of node j to node i
                up to t orders; multiplied element-wise into the raw scores.
            concat (bool, optional): if True, apply ELU to the aggregated
                features; otherwise return them without the activation.

        Returns:
            (torch.Tensor): latent representation.
        """
        h = torch.mm(x, self.W)

        attn_for_self = torch.mm(h, self.a_self)  # (N,1)
        attn_for_neighs = torch.mm(h, self.a_neighs)  # (N,1)
        # Broadcasting a column against a row gives the pairwise scores.
        attn_dense = attn_for_self + torch.transpose(attn_for_neighs, 0, 1)
        attn_dense = torch.mul(attn_dense, M)
        attn_dense = self.leakyrelu(attn_dense)  # (N,N)

        # Mask in the scores' own dtype rather than adj's: a float64 (or
        # integer) adjacency must not promote the attention matrix away from
        # the dtype of ``h``, or the matmul below fails. The clamp keeps the
        # fill value finite for low-precision dtypes.
        fill_value = max(-9e15, torch.finfo(attn_dense.dtype).min)
        masked = attn_dense.masked_fill(~(adj > 0), fill_value)

        attention = F.softmax(masked, dim=1)
        h_prime = torch.matmul(attention, h)

        return F.elu(h_prime) if concat else h_prime

    def __repr__(self):
        return (f"{self.__class__.__name__} "
                f"({self.in_features} -> {self.out_features})")