Source code for egc.module.layers.gcl_sublime

"""
Graph Contrastive Learning Model for SUBLIME
"""
import copy

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear
from torch.nn import ReLU
from torch.nn import Sequential

from .gcn_sublime import GCNConv_dense
from .gcn_sublime import GCNConv_dgl

# pylint:disable=no-else-return,protected-access


class GCL_SUBLIME(nn.Module):
    """Graph contrastive learning model of SUBLIME

    Args:
        nlayers (int): Number of GCN layers
        in_dim (int): Input feature dimension
        hidden_dim (int): Hidden dimension
        emb_dim (int): Embedding dimension
        proj_dim (int): Projection dimension
        dropout (float): Dropout rate
        dropout_adj (float): Edge dropout rate
        sparse (int): Whether to run in sparse (DGL) mode
    """

    def __init__(
        self,
        nlayers,
        in_dim,
        hidden_dim,
        emb_dim,
        proj_dim,
        dropout,
        dropout_adj,
        sparse,
    ):
        super().__init__()
        self.encoder = GraphEncoder(nlayers, in_dim, hidden_dim, emb_dim,
                                    proj_dim, dropout, dropout_adj, sparse)

    def forward(self, x, Adj_, branch=None):
        z, embedding = self.encoder(x, Adj_, branch)
        return z, embedding

    @staticmethod
    def calc_loss(x, x_aug, temperature=0.2, sym=True):
        batch_size, _ = x.size()
        x_abs = x.norm(dim=1)
        x_aug_abs = x_aug.norm(dim=1)

        # sim_matrix[i, j] = exp(cosine_sim(x_i, x_aug_j) / temperature);
        # the diagonal holds the positive pairs.
        sim_matrix = torch.einsum("ik,jk->ij", x, x_aug) / torch.einsum(
            "i,j->ij", x_abs, x_aug_abs)
        sim_matrix = torch.exp(sim_matrix / temperature)
        pos_sim = sim_matrix[range(batch_size), range(batch_size)]
        if sym:
            # Symmetric NT-Xent-style loss: contrast in both directions
            # (columns via dim=0, rows via dim=1), excluding the positive
            # pair from each denominator, then average.
            loss_0 = pos_sim / (sim_matrix.sum(dim=0) - pos_sim)
            loss_1 = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)

            loss_0 = -torch.log(loss_0).mean()
            loss_1 = -torch.log(loss_1).mean()
            loss = (loss_0 + loss_1) / 2.0
            return loss
        else:
            loss_1 = pos_sim / (sim_matrix.sum(dim=1) - pos_sim)
            loss_1 = -torch.log(loss_1).mean()
            return loss_1
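
# A minimal usage sketch (not part of the original source): run the model in
# dense mode (sparse=0) on two views of a toy graph and compute the
# contrastive loss. All names, shapes, and hyperparameters below are
# illustrative; in SUBLIME the two views would come from the anchor and
# learner branches rather than the same adjacency twice.
def _demo_gcl_sublime():
    n_nodes, in_dim = 8, 16
    model = GCL_SUBLIME(nlayers=2, in_dim=in_dim, hidden_dim=32, emb_dim=32,
                        proj_dim=32, dropout=0.5, dropout_adj=0.25, sparse=0)
    model.eval()  # disable dropout for a deterministic sanity check
    x = torch.rand(n_nodes, in_dim)
    adj = torch.rand(n_nodes, n_nodes)  # stand-in for a learned adjacency
    z_anchor, _ = model(x, adj)
    z_learner, _ = model(x, adj)
    return GCL_SUBLIME.calc_loss(z_anchor, z_learner)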

class SparseDropout(nn.Module):
    """Sparse Dropout

    Args:
        dprob (float): Dropout ratio. Defaults to 0.5.
    """

    def __init__(self, dprob=0.5):
        super().__init__()
        # dprob is the dropout ratio; convert to the keep probability
        self.kprob = 1 - dprob

    def forward(self, x):
        # Bernoulli keep-mask over the non-zero entries, drawn on the same
        # device as x: rand + kprob floors to 1 with probability kprob.
        mask = ((torch.rand(x._values().size(), device=x.device) +
                 self.kprob).floor()).type(torch.bool)
        rc = x._indices()[:, mask]
        # Rescale surviving values by 1 / kprob (inverted dropout).
        val = x._values()[mask] * (1.0 / self.kprob)
        return torch.sparse_coo_tensor(rc, val, x.shape)
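
# A minimal usage sketch (not part of the original source): apply
# SparseDropout to a tiny coalesced sparse COO adjacency. Values are
# illustrative; surviving entries come back rescaled by 1 / kprob.
def _demo_sparse_dropout():
    indices = torch.tensor([[0, 1, 2], [1, 2, 0]])
    values = torch.ones(3)
    adj = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce()
    return SparseDropout(dprob=0.5)(adj)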

class GraphEncoder(nn.Module):
    """Graph encoder of the GSL model

    Args:
        nlayers (int): Number of GCN layers
        in_dim (int): Input feature dimension
        hidden_dim (int): Hidden dimension
        emb_dim (int): Embedding dimension
        proj_dim (int): Projection dimension
        dropout (float): Dropout rate
        dropout_adj (float): Edge dropout rate
        sparse (int): Whether to run in sparse (DGL) mode
    """

    def __init__(
        self,
        nlayers,
        in_dim,
        hidden_dim,
        emb_dim,
        proj_dim,
        dropout,
        dropout_adj,
        sparse,
    ):
        super().__init__()
        self.dropout = dropout
        self.dropout_adj_p = dropout_adj
        self.sparse = sparse

        self.gnn_encoder_layers = nn.ModuleList()
        if sparse:
            self.gnn_encoder_layers.append(GCNConv_dgl(in_dim, hidden_dim))
            for _ in range(nlayers - 2):
                self.gnn_encoder_layers.append(
                    GCNConv_dgl(hidden_dim, hidden_dim))
            self.gnn_encoder_layers.append(GCNConv_dgl(hidden_dim, emb_dim))
        else:
            self.gnn_encoder_layers.append(GCNConv_dense(in_dim, hidden_dim))
            for _ in range(nlayers - 2):
                self.gnn_encoder_layers.append(
                    GCNConv_dense(hidden_dim, hidden_dim))
            self.gnn_encoder_layers.append(GCNConv_dense(hidden_dim, emb_dim))

        if self.sparse:
            self.dropout_adj = SparseDropout(dprob=dropout_adj)
        else:
            self.dropout_adj = nn.Dropout(p=dropout_adj)

        self.proj_head = Sequential(Linear(emb_dim, proj_dim),
                                    ReLU(inplace=True),
                                    Linear(proj_dim, proj_dim))

    def forward(self, x, Adj_, branch=None):
        if self.sparse:
            if branch == "anchor":
                # Deep-copy the anchor graph so that edge dropout does not
                # mutate the shared DGL graph in place.
                Adj = copy.deepcopy(Adj_)
            else:
                Adj = Adj_
            Adj.edata["w"] = F.dropout(Adj.edata["w"],
                                       p=self.dropout_adj_p,
                                       training=self.training)
        else:
            Adj = self.dropout_adj(Adj_)

        for conv in self.gnn_encoder_layers[:-1]:
            x = conv(x, Adj)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.gnn_encoder_layers[-1](x, Adj)
        z = self.proj_head(x)
        return z, x
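
# A minimal usage sketch (not part of the original source): call the dense
# GraphEncoder directly. ``z`` is the projection-head output fed to the
# contrastive loss; the second return value is the raw GCN embedding.
# Shapes are illustrative.
def _demo_graph_encoder():
    n_nodes, in_dim = 8, 16
    encoder = GraphEncoder(nlayers=2, in_dim=in_dim, hidden_dim=32,
                           emb_dim=32, proj_dim=32, dropout=0.5,
                           dropout_adj=0.25, sparse=0)
    encoder.eval()
    x = torch.rand(n_nodes, in_dim)
    adj = torch.rand(n_nodes, n_nodes)
    z, embedding = encoder(x, adj)
    assert z.shape == (n_nodes, 32) and embedding.shape == (n_nodes, 32)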