# Source code for egc.module.layers.gcn

"""
GCN Layer
Adapted from: https://github.com/PetarV-/DGI
"""
from typing import Callable, Tuple, Union

import torch
from torch import nn

from ...utils import init_weights


class GCN(nn.Module):
    """GCN Layer

    Computes ``activation(\\hat{A}XW + b)`` where ``XW`` is a bias-free linear
    projection of the node features and ``\\hat{A}`` is the (already
    normalized) adjacency matrix supplied to :meth:`forward`.

    Args:
        in_feats (int): input feature dimension
        out_feats (int): output feature dimension
        activation (str | Callable | None): activation function. ``"prelu"``
            selects ``nn.PReLU`` and ``"relu"`` selects ``nn.ReLU``; any other
            callable (e.g. an ``nn.Module``) is used as-is; ``None`` applies no
            activation. Defaults to ``"prelu"``.
        bias (bool): whether to apply bias after calculate \\hat{A}XW.
            Defaults to True.

    Raises:
        ValueError: if ``activation`` is not ``"prelu"``, ``"relu"``, ``None``
            or a callable.
    """

    def __init__(
        self,
        in_feats: int,
        out_feats: int,
        activation: Union[str, Callable[..., torch.Tensor], None] = "prelu",
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Bias-free projection XW; the bias is added after neighbourhood
        # aggregation in forward(), i.e. to \hat{A}XW rather than to XW.
        self.f_c = nn.Linear(in_feats, out_feats, bias=False)

        if activation == "prelu":
            self.activation = nn.PReLU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        elif activation is None:
            self.activation = nn.Identity()
        elif callable(activation):
            # User-supplied activation (module or plain function).
            self.activation = activation
        else:
            # Fail fast: an unknown string would otherwise only surface in
            # forward() as "'str' object is not callable".
            raise ValueError(
                f"Unsupported activation {activation!r}; expected 'prelu', "
                "'relu', None or a callable."
            )

        if bias:
            self.bias = nn.Parameter(torch.zeros(out_feats, dtype=torch.float32))
        else:
            # Register explicitly so `self.bias` exists (as None) either way.
            self.register_parameter("bias", None)

        # NOTE(review): init_weights is a project helper applied to every
        # submodule (including this layer itself); its exact effect is not
        # visible from this file.
        for module in self.modules():
            init_weights(module)

    def forward(
        self,
        features: torch.Tensor,
        adj_norm: torch.Tensor,
        sparse: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward Propagation

        Args:
            features (torch.Tensor): normalized 3D features tensor of shape
                ``(1, num_nodes, in_feats)``.
            adj_norm (torch.Tensor): symmetrically normalized adjacency tensor
                of shape ``(num_nodes, num_nodes)``. When ``sparse`` is False,
                a dense batched ``(batch, num_nodes, num_nodes)`` tensor is
                also accepted (``features`` must then be
                ``(batch, num_nodes, in_feats)``).
            sparse (bool): whether ``adj_norm`` is a sparse tensor.

        Returns:
            out, hidden_layer (torch.Tensor, torch.Tensor):
            ``activation(\\hat{A}XW + b)`` of shape ``(1, num_nodes, out_feats)``
            and ``XW``.
        """
        hidden_layer = self.f_c(features)
        if sparse:
            out = torch.unsqueeze(
                torch.spmm(adj_norm, torch.squeeze(hidden_layer, 0)), 0
            )
        elif adj_norm.dim() == 2:
            # torch.bmm requires 3D operands, so a 2D dense adjacency is
            # multiplied against the squeezed 2D features (as in the sparse
            # branch) and the batch dimension is restored afterwards.
            out = torch.unsqueeze(
                torch.mm(adj_norm, torch.squeeze(hidden_layer, 0)), 0
            )
        else:
            # Batched dense adjacency: (B, N, N) @ (B, N, F) -> (B, N, F).
            out = torch.bmm(adj_norm, hidden_layer)
        if self.bias is not None:
            out = out + self.bias
        return self.activation(out), hidden_layer