Source code for egc.module.layers.disc_communitygan

"""
Discriminator Layer
Adapted from: https://github.com/SamJia/CommunityGAN
"""
import random
from typing import Callable
from typing import List
from typing import Tuple

import numpy as np
import torch
from torch import nn


class DiscComGAN(nn.Module):
    """Discriminator of CommunityGAN.

    Args:
        n_nodes (int): number of nodes.
        node_emd_init (torch.Tensor): node embeddings in AGM (Affiliation
            Graph Model) format, pretrained in advance.
        n_epochs (int): number of training epochs.
        dis_interval (int): number of epochs between regenerations of the
            discriminator's training data.
        update_ratio (float): probability with which each node contributes
            samples in a data-generation round.
        n_sample_dis (int): number of motif samples per node for the
            discriminator.
        lr_dis (float): learning rate.
        l2_coef (float): L2 regularization coefficient.
        batch_size (int): batch size.
        max_value (int): maximum value for entries of the embedding matrix.
    """

    def __init__(
        self,
        n_nodes: int,
        node_emd_init: torch.Tensor,
        n_epochs: int,
        dis_interval: int,
        update_ratio: float,
        n_sample_dis: int,
        lr_dis: float,
        l2_coef: float,
        batch_size: int,
        max_value: int,
    ):
        super().__init__()
        self.n_nodes = n_nodes
        self.n_epochs = n_epochs
        self.dis_interval = dis_interval
        self.update_ratio = update_ratio
        self.n_sample_dis = n_sample_dis
        self.batch_size = batch_size
        self.max_value = max_value
        # Trainable node-embedding matrix, initialized from the pretrained
        # AGM embeddings.
        self.embedding_matrix = nn.Parameter(
            torch.FloatTensor(node_emd_init), requires_grad=True)
        self.optimizer = torch.optim.Adam(
            self.parameters(), lr=lr_dis, weight_decay=l2_coef)
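    # A minimal construction sketch (all sizes and hyper-parameter values
    # below are hypothetical; `node_emd_init` would normally come from a
    # pretrained AGM model rather than random initialization):
    #
    #   init = torch.rand(100, 5)  # 100 nodes, 5 communities, non-negative
    #   disc = DiscComGAN(n_nodes=100, node_emd_init=init, n_epochs=10,
    #                     dis_interval=1, update_ratio=1.0, n_sample_dis=5,
    #                     lr_dis=1e-3, l2_coef=1e-4, batch_size=64,
    #                     max_value=10)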
    def prepare_data_for_d(
            self, sampling: Callable,
            id2motifs: List[List[Tuple]]) -> Tuple[List[Tuple], List[List]]:
        """Generate positive and negative samples for the discriminator.

        Args:
            sampling (Callable): sampling function that generates negative
                motifs.
            id2motifs (List[List[Tuple]]): lists of motifs indexed by node id.

        Returns:
            Tuple[List[Tuple], List[List]]: (list of sampled motifs, list of
            labels)
        """
        motifs = []
        labels = []
        g_s_args = []
        poss = []
        negs = []
        for i in range(self.n_nodes):
            if np.random.rand() < self.update_ratio:
                # Positive samples: motifs observed around node i.
                pos = random.sample(
                    id2motifs[i], min(len(id2motifs[i]), self.n_sample_dis))
                poss.append(pos)
                g_s_args.append((i, len(pos), True))
        # Negative samples: motifs produced by the sampling function.
        negs, _ = sampling(g_s_args)
        for pos, neg in zip(poss, negs):
            if len(pos) != 0 and neg is not None:
                motifs.extend(pos)
                labels.extend([1] * len(pos))
                motifs.extend(neg)
                labels.extend([0] * len(neg))
        # Shuffle motifs and labels jointly.
        motifs_idx = list(range(len(motifs)))
        np.random.shuffle(motifs_idx)
        motifs = [motifs[i] for i in motifs_idx]
        labels = [labels[i] for i in motifs_idx]
        return motifs, labels
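    # The `sampling` callable is assumed to take a list of
    # (root_node, n_samples, discriminator_flag) tuples and return a pair
    # (list of negative-motif lists, auxiliary data). A trivial stand-in for
    # testing might look like (illustration only, not the real generator):
    #
    #   def fake_sampling(args):
    #       return [[tuple(np.random.randint(0, 100, 3)) for _ in range(m)]
    #               for _, m, _ in args], None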
    def forward(self, motifs: List[Tuple],
                label: List[List] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass.

        Args:
            motifs (List[Tuple]): motifs.
            label (List[List], optional): labels. Defaults to None.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: (loss, reward)
        """
        # Motif score: element-wise product of the member nodes' embeddings,
        # summed over the embedding dimension.
        score = torch.sum(torch.prod(self.embedding_matrix[motifs], dim=1),
                          dim=1)
        # Probability that the motif is real, clipped for numerical stability.
        p = torch.clip(1 - torch.exp(-score), 1e-5, 1)
        reward = 1 - p
        loss = (-torch.sum(label * p + (1 - label) * (1 - p))
                if label is not None else None)
        return loss, reward
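    # Shape walk-through (illustrative): for a batch of B motifs of size k
    # and embedding dimension d,
    #
    #   self.embedding_matrix[motifs]   -> (B, k, d)  embedding lookup
    #   torch.prod(..., dim=1)          -> (B, d)     product over motif nodes
    #   torch.sum(..., dim=1)           -> (B,)       motif score
    #
    # so p = 1 - exp(-score) follows the AGM-style motif probability used by
    # CommunityGAN.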
    def get_reward(self, motifs: List[Tuple],
                   label: List[List] = None) -> np.ndarray:
        """Get reward.

        Args:
            motifs (List[Tuple]): motifs.
            label (List[List], optional): labels. Defaults to None.

        Returns:
            np.ndarray: reward.
        """
        _, reward = self.forward(motifs, label)
        return reward.detach().numpy()
    def fit(self, sampling: Callable, id2motifs: List[List[Tuple]]) -> None:
        """Fit.

        Args:
            sampling (Callable): sampling function.
            id2motifs (List[List[Tuple]]): lists of motifs indexed by node id.
        """
        motifs = []
        labels = []
        for epoch in range(self.n_epochs):
            self.train()
            # Regenerate training data every `dis_interval` epochs.
            if epoch % self.dis_interval == 0:
                motifs, labels = self.prepare_data_for_d(sampling, id2motifs)
            train_size = len(motifs)
            start_list = list(range(0, train_size, self.batch_size))
            np.random.shuffle(start_list)
            for start in start_list:
                self.zero_grad()
                end = start + self.batch_size
                loss, _ = self.forward(torch.LongTensor(motifs[start:end]),
                                       torch.Tensor(labels[start:end]))
                loss.backward()
                self.optimizer.step()
                # Project the embeddings back onto [0, max_value] after each
                # step, keeping the AGM affiliations non-negative.
                self.embedding_matrix.data = torch.clip(
                    self.embedding_matrix.data, 0, self.max_value)
            print(f"discriminator epoch {epoch} loss {loss}")
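    # A usage sketch, reusing the hypothetical `fake_sampling` helper from the
    # comments above (names and values are illustrative):
    #
    #   id2motifs = [[(i, (i + 1) % 100, (i + 2) % 100)] for i in range(100)]
    #   disc.fit(fake_sampling, id2motifs)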
    def get_embedding(self) -> torch.Tensor:
        """Get the embeddings (graph or node level).

        Returns:
            torch.Tensor: embedding.
        """
        return self.embedding_matrix.detach()
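
# A minimal end-to-end smoke-test sketch. The `fake_sampling` helper and all
# sizes and hyper-parameters here are hypothetical stand-ins, not values from
# the original CommunityGAN experiments.
if __name__ == "__main__":
    n, k, dim = 100, 3, 5
    init = torch.rand(n, dim)  # non-negative initial AGM embeddings

    def fake_sampling(args):
        # Return one list of random k-node motifs per request; this stands in
        # for the generator's motif sampler.
        return [[tuple(np.random.randint(0, n, k)) for _ in range(m)]
                for _, m, _ in args], None

    # One toy triangle motif per node.
    id2motifs = [[(i, (i + 1) % n, (i + 2) % n)] for i in range(n)]

    disc = DiscComGAN(n_nodes=n, node_emd_init=init, n_epochs=2,
                      dis_interval=1, update_ratio=1.0, n_sample_dis=2,
                      lr_dis=1e-3, l2_coef=1e-4, batch_size=32, max_value=10)
    disc.fit(fake_sampling, id2motifs)
    print(disc.get_embedding().shape)  # torch.Size([100, 5])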