Source code for egc.module.layers.grace_secomm

"""
GraceModel for SEComm
"""
# pylint:disable=W0223
import dgl
import torch
import torch.nn.functional as F
from dgl.nn.pytorch.conv import GraphConv
from torch import nn


class SECommEncoder(torch.nn.Module):
    """SECommEncoder: a stack of ``k`` graph-convolution layers.

    Layer widths are
    ``in_channels -> 2*out_channels -> ... -> 2*out_channels -> out_channels``,
    i.e. one input layer, ``k - 2`` hidden layers of width ``2*out_channels``
    and one output layer. Every layer, including the last, is constructed with
    the same ``activation``.

    Args:
        in_channels: Dimension of the input node features.
        out_channels: Dimension of the output node embeddings.
        activation: Activation passed to every ``base_model`` layer.
        base_model: Layer class, called as
            ``base_model(in_dim, out_dim, activation=activation)``.
            Defaults to DGL's ``GraphConv``.
        k: Number of layers; must be at least 2.

    Raises:
        ValueError: If ``k < 2``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation,
        base_model=GraphConv,
        k: int = 2,
    ):
        super().__init__()
        self.base_model = base_model
        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would silently build a mis-sized network.
        if k < 2:
            raise ValueError(f"k must be >= 2, got {k}")
        self.k = k
        hidden = 2 * out_channels
        # Width of each layer boundary: k layers need k + 1 widths.
        dims = [in_channels] + [hidden] * (k - 1) + [out_channels]
        # Layers are created in input -> output order, matching the original
        # construction order, so parameter initialisation is unchanged.
        self.conv = nn.ModuleList(
            base_model(d_in, d_out, activation=activation)
            for d_in, d_out in zip(dims[:-1], dims[1:]))

    def forward(self, g: dgl.DGLGraph, feats: torch.Tensor) -> torch.Tensor:
        """Run all ``k`` layers over ``g`` and return the final embeddings.

        Args:
            g: Input graph. A self-loop is added to every node before
                convolving; ``g`` itself is not modified.
            feats: Input node features fed to the first layer.

        Returns:
            Output of the last layer.
        """
        # NOTE(review): dgl.add_self_loop adds a loop to every node even if
        # one already exists, so graphs that already carry self-loops end up
        # with duplicates. Call dgl.remove_self_loop first if that matters.
        g = dgl.add_self_loop(g)
        for layer in self.conv:
            feats = layer(g, feats)
        return feats
class SECommGraceModel(torch.nn.Module):
    """GRACE-style contrastive model for SEComm.

    Wraps an encoder and adds a two-layer projection head
    (``fc1`` -> ELU -> ``fc2``) used only when computing the contrastive loss
    between two views' embeddings.

    Args:
        encoder: Encoder called as ``encoder(g, feats)`` by ``forward``.
        num_hidden: Embedding width (input and output width of the
            projection head).
        num_proj_hidden: Hidden width of the projection head.
        tau: Temperature that divides the cosine similarities in the loss.
    """

    def __init__(
        self,
        encoder: SECommEncoder,
        num_hidden: int,
        num_proj_hidden: int,
        tau: float = 0.5,
    ):
        super().__init__()
        self.encoder: SECommEncoder = encoder
        self.tau: float = tau
        self.fc1 = torch.nn.Linear(num_hidden, num_proj_hidden)
        self.fc2 = torch.nn.Linear(num_proj_hidden, num_hidden)

    def forward(self, g: dgl.DGLGraph, feats: torch.Tensor) -> torch.Tensor:
        """Return the encoder's embeddings for graph ``g`` and ``feats``."""
        return self.encoder(g, feats)

    def projection(self, z: torch.Tensor) -> torch.Tensor:
        """Apply the projection head ``fc2(elu(fc1(z)))``."""
        return self.fc2(F.elu(self.fc1(z)))

    # pylint:disable=R0201
    def sim(self, z1: torch.Tensor, z2: torch.Tensor):
        """Return the matrix of cosine similarities between rows.

        Rows of both 2-D inputs are L2-normalised, so entry ``[i, j]`` is the
        cosine similarity of ``z1[i]`` and ``z2[j]``; shape ``[len(z1), len(z2)]``.
        """
        z1 = F.normalize(z1)
        z2 = F.normalize(z2)
        return torch.mm(z1, z2.t())

    def _exp_sim(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Return ``exp(sim(z1, z2) / tau)`` element-wise."""
        return torch.exp(self.sim(z1, z2) / self.tau)

    def semi_loss(self, z1: torch.Tensor, z2: torch.Tensor):
        """Exact per-node GRACE loss of view ``z1`` against view ``z2``.

        For node ``i`` the positive is ``(z1[i], z2[i])``. The denominator sums
        over all nodes of both views, excluding the self-pair ``(z1[i], z1[i])``.
        Builds full ``[N, N]`` similarity matrices.

        Returns:
            Tensor with one loss value per row of ``z1``.
        """
        refl_sim = self._exp_sim(z1, z1)
        between_sim = self._exp_sim(z1, z2)
        return -torch.log(
            between_sim.diag()
            / (refl_sim.sum(1) + between_sim.sum(1) - refl_sim.diag()))

    def batched_semi_loss(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor,
        batch_size: int,
    ):
        """Memory-bounded, sampled approximation of :meth:`semi_loss`.

        Anchors (rows of ``z1``) are processed in contiguous chunks of
        ``batch_size`` rows. For each chunk, the negatives are an equally sized
        slice of a single random permutation of all nodes, drawn once per call,
        taken from both views. For anchor ``i`` the loss is::

            log(sum_j exp(s(z1_i, z1_j)/tau) + sum_j exp(s(z1_i, z2_j)/tau))
                - s(z1_i, z2_i)/tau

        where ``j`` ranges over the sampled negatives and ``s`` is cosine
        similarity. Unlike ``semi_loss``, the anchor's self-similarity is not
        removed, and the positive pair enters the denominator only if ``i``
        happens to be sampled. Similarity matrices are ``[B, B]`` per chunk.

        Args:
            z1, z2: Embeddings of the two views, row-aligned.
            batch_size: Chunk size ``B``; must be positive.

        Returns:
            Tensor with one loss value per row of ``z1``.

        Raises:
            ValueError: If ``batch_size <= 0``.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        device = z1.device
        num_nodes = z1.size(0)
        # Drawn on the default (CPU) generator and then moved, so negatives
        # sampled under a fixed seed do not depend on the device.
        rand_indices = torch.randperm(num_nodes).to(device)
        losses = []
        for start in range(0, num_nodes, batch_size):
            stop = start + batch_size
            anchors = z1[start:stop]
            negatives = rand_indices[start:stop]
            refl_sim = self._exp_sim(anchors, z1[negatives])  # [B, B]
            between_sim = self._exp_sim(anchors, z2[negatives])  # [B, B]
            # log-sum form of -log(exp(pos / tau) / denominator).
            positive = (F.normalize(anchors)
                        * F.normalize(z2[start:stop])).sum(1) / self.tau
            losses.append(
                torch.log(refl_sim.sum(1) + between_sim.sum(1)) - positive)
        return torch.cat(losses)

    def loss(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor,
        mean: bool = True,
        batch_size: int = 0,
    ):
        """Symmetric contrastive loss between two views' embeddings.

        Both inputs pass through the projection head. The per-node losses of
        ``z1`` vs ``z2`` and ``z2`` vs ``z1`` are averaged, then reduced.

        Args:
            z1, z2: Encoder outputs of the two views, row-aligned.
            mean: Reduce with mean if True, otherwise with sum.
            batch_size: ``0`` uses the exact :meth:`semi_loss`; any other
                value uses :meth:`batched_semi_loss` with that chunk size.

        Returns:
            Scalar loss tensor.
        """
        h1 = self.projection(z1)
        h2 = self.projection(z2)
        if batch_size == 0:
            l1 = self.semi_loss(h1, h2)
            l2 = self.semi_loss(h2, h1)
        else:
            l1 = self.batched_semi_loss(h1, h2, batch_size)
            l2 = self.batched_semi_loss(h2, h1, batch_size)
        ret = (l1 + l2) * 0.5
        return ret.mean() if mean else ret.sum()