Source code for egc.module.layers.disc_dgi

"""
Discriminator Layer
Adapted from: https://github.com/PetarV-/DGI
"""
import torch
from torch import nn

from ...utils import init_weights


[docs]class DiscDGI(nn.Module): """Discriminator for DGI Args: hidden_units (int): hidden units dimension. Defaults to 512. bias (bool): whether to apply bias to xWy. Defaults to False. """ def __init__(self, hidden_units: int = 512, bias: bool = True) -> None: super().__init__() self.f_k = nn.Bilinear(hidden_units, hidden_units, 1, bias) for module in self.modules(): init_weights(module)
[docs] def forward( self, g: torch.Tensor, h: torch.Tensor, h_shf: torch.Tensor, ) -> torch.Tensor: """Forward Propagation Args: g (torch.Tensor): avg readout of whole graph, 1D tensor. h (torch.Tensor): node embedding. 3D tensor. h_shf (torch.Tensor): shuffled node embedding as \ corrupted graph node embedding. 3D tensor. Returns: (torch.Tensor): concat of pos and neg disc output. """ g_x = g if g.shape == h.shape else torch.unsqueeze(g, 1).expand_as(h) s_c_pos = torch.squeeze(self.f_k(h, g_x), 2) s_c_neg = torch.squeeze(self.f_k(h_shf, g_x), 2) return torch.cat((s_c_pos, s_c_neg), 1)