Source code for egc.module.layers.disc_gmi

"""
Discriminator Layer
Adapted from: https://github.com/PetarV-/DGI & https://github.com/zpeng27/GMI
"""
import torch
from torch import nn

from ...utils import init_weights


class DiscGMI(nn.Module):
    """Bilinear discriminator layer scoring pairs of inputs as ``act(x W y + b)``.

    Args:
        in1_features (int): size of each first input sample.
        in2_features (int): size of each second input sample.
        out_features (int): size of each output sample. Defaults to 1.
        activation (str): activation applied to xWy. ``"sigmoid"`` selects
            ``nn.Sigmoid``; any other value falls back to ``nn.ReLU``.
            Defaults to ``"sigmoid"``.
        bias (bool): whether to add a bias term to xWy. Defaults to True.
    """

    def __init__(
        self,
        in1_features: int,
        in2_features: int,
        out_features: int = 1,
        activation: str = "sigmoid",
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.f_k = nn.Bilinear(in1_features, in2_features, out_features, bias)
        # Anything other than "sigmoid" intentionally maps to ReLU.
        self.activation = nn.Sigmoid() if activation == "sigmoid" else nn.ReLU()
        # self.modules() yields this module itself as well as its children.
        for module in self.modules():
            init_weights(module)

    def forward(
        self,
        in1_features: torch.Tensor,
        in2_features: torch.Tensor,
        neg_sample_list: list = None,
    ) -> tuple:
        """Forward Propagation

        Args:
            in1_features (torch.Tensor): first input sample in shape of
                [1, xx, xx].
            in2_features (torch.Tensor): second input sample in shape of
                [1, xx, xx].
            neg_sample_list (List, optional): list of negative sampling node
                indices into the first input. Defaults to None.

        Returns:
            s_c, s_c_neg (torch.Tensor): activated discriminator output for
            the given pairs and for the negative samples. ``s_c_neg`` is
            None when ``neg_sample_list`` is None.
        """
        # Dim 2 is dropped only when it has size 1 (i.e. out_features == 1).
        s_c = self.activation(
            torch.squeeze(self.f_k(in1_features, in2_features), 2))
        if neg_sample_list is None:
            return s_c, None
        s_c_neg = self._neg_sample_forward(in1_features, in2_features,
                                           neg_sample_list)
        return s_c, s_c_neg

    def _neg_sample_forward(self, in1_features, in2_features, neg_sample_list):
        """Pointwise Forward Propagation

        Args:
            in1_features (torch.Tensor): first input sample in shape of
                [1, xx, xx].
            in2_features (torch.Tensor): second input sample in shape of
                [1, xx, xx].
            neg_sample_list (List): list of negative sampling node indices
                into the first input; one score row is produced per entry.
                Must be non-empty, since ``torch.stack`` rejects an empty
                list.

        Returns:
            s_c_neg (torch.Tensor): activated discriminator output for the
            negative sampling list.
        """
        # Only the first batch element of in1_features is sampled from.
        nodes = in1_features[0]
        scores = [
            torch.squeeze(
                self.f_k(torch.unsqueeze(nodes[neg_idx], 0), in2_features), 2)
            for neg_idx in neg_sample_list
        ]
        # Stack per-sample scores along dim 1, then drop the leading dim if
        # it has size 1.
        return self.activation(torch.squeeze(torch.stack(scores, 1), 0))