Source code for egc.module.layers.gmi

"""
GMI
Adapted From: https://github.com/zpeng27/GMI
"""
from typing import List

import scipy.sparse as sp
import torch
from torch import nn

from ...utils import sparse_mx_to_torch_sparse_tensor
from .disc_gmi import DiscGMI
from .gcn import GCN


[docs]def avg_neighbor(features: torch.Tensor, adj_orig: sp.csr_matrix) -> torch.Tensor: """Aggregate Neighborhood Using Original Adjacency Matrix Args: features (torch.Tensor): 2D row-normalized features. adj_orig (<class 'scipy.sparse.csr.csr_matrix'>): row-avaraged adj. Returns: (torch.Tensor): row-avaraged aggregation of neighborhood. """ adj_orig = sparse_mx_to_torch_sparse_tensor(adj_orig) if torch.cuda.is_available(): adj_orig = adj_orig.cuda() return torch.unsqueeze(torch.spmm(adj_orig, torch.squeeze(features, 0)), 0)
[docs]class GMI(nn.Module): """GMI Args: in_features (int): input feature dimension. hidden_units (int): output hidden units dimension. activation (str): activation of gcn layer. Defaults to prelu. """ def __init__( self, in_features: int, hidden_units: int, gcn_depth: int = 2, activation: str = "prelu", ) -> None: super().__init__() self.gcn_depth = gcn_depth if gcn_depth == 2: self.gcn_1 = GCN(in_features, hidden_units, activation) self.gcn_2 = GCN(hidden_units, hidden_units, activation) elif gcn_depth == 1: self.gcn_1 = GCN(in_features, hidden_units, activation) else: raise ValueError( "Now gcn_depth only supports 1 or 2 layers, otherwise modify the code on you own." ) self.disc_1 = DiscGMI(in_features, hidden_units, activation="sigmoid") self.disc_2 = DiscGMI(hidden_units, hidden_units, activation="sigmoid") self.prelu = nn.PReLU() self.sigmoid = nn.Sigmoid()
[docs] def forward( self, features_norm: torch.Tensor, adj_orig: sp.csr_matrix, adj_norm: torch.Tensor, neg_sample_list: List, ): """Forward Propagation Args: features_norm (torch.Tensor): row-normalized features. adj_orig (sp.csr_matrix): row-avaraged adj. adj_norm (torch.Tensor): symmetrically normalized sparse tensor adj. neg_sample_list (List): list of multiple repeatable shuffle of nodes index list. Returns: mi_pos, mi_neg, local_mi_pos, local_mi_neg, adj_rebuilt (torch.Tensor): D_w(h_i, x_i), D_w(h_i, x'_i), D_w(h_i, x_j), D_w(h_i, x'_j), w_{ij} """ if self.gcn_depth == 1: h_2, h_w = self.gcn_1(features_norm, adj_norm) else: h_1, h_w = self.gcn_1(features_norm, adj_norm) h_2, _ = self.gcn_2(h_1, adj_norm) h_neighbor = self.prelu(avg_neighbor(h_w, adj_orig)) mi_pos, mi_neg = self.disc_1(features_norm, h_2, neg_sample_list) local_mi_pos, local_mi_neg = self.disc_2(h_neighbor, h_2, neg_sample_list) adj_rebuilt = self.sigmoid( torch.mm(torch.squeeze(h_2), torch.t(torch.squeeze(h_2)))) return mi_pos, mi_neg, local_mi_pos, local_mi_neg, adj_rebuilt
[docs] def get_embedding(self, features_norm, adj_norm): """Get Node Embedding Args: features_norm (torch.Tensor): row-normalized features. adj_norm (torch.Tensor): symmetrically normalized adj. Returns: (torch.Tensor): node embedding. """ if self.gcn_depth == 1: h_2, _ = self.gcn_1(features_norm, adj_norm) else: h_1, _ = self.gcn_1(features_norm, adj_norm) h_2, _ = self.gcn_2(h_1, adj_norm) return h_2.detach()