Source code for egc.model.graph_clustering.overlapping.communityGAN

"""
CommunityGAN
Adapted from: https://github.com/SamJia/CommunityGAN
"""
import math
from typing import Dict
from typing import List
from typing import Set
from typing import Tuple

import numpy as np
import torch
from torch import nn

from ....module import DiscComGAN
from ....module import GeneComGAN
from ....utils import CommunityGANSampling
from ..base import Base


[docs]class CommunityGAN(Base, nn.Module): """CommunityGAN Args: n_nodes (int): num of nodes. node_emd_init_gen (torch.Tensor): initial node embedding for generator. node_emd_init_dis (torch.Tensor): initial node embedding for discriminator. max_value (float, optional): max value of embedding. Defaults to 1000. n_epochs (int, optional): num of training epochs. Defaults to 10. n_epochs_gen (int, optional): num of training epochs for generator. Defaults to 3. n_epochs_dis (int, optional): num of traing epochs for discriminator. Defaults to 3. gen_interval (int, optional): interval of generator. Defaults to 3. dis_interval (int, optional): interval of discriminator. Defaults to 3. update_ratio (float, optional): update ratio. Defaults to 1.0. n_sample_gen (int, optional): num of samples of generator. Defaults to 5. n_sample_dis (int, optional): num of samples of discriminator. Defaults to 5. lr_gen (float, optional): learning rate of generator. Defaults to 1e-3. lr_dis (float, optional): learning rate of discriminator. Defaults to 1e-3. l2_coef (float, optional): l2 coef of optimizers. Defaults to 0.0. batch_size_gen (int, optional): batch size of generator. Defaults to 64. batch_size_dis (int, optional): batch size of discriminator. Defaults to 64. """ def __init__( self, n_nodes: int, node_emd_init_gen: torch.Tensor, node_emd_init_dis: torch.Tensor, max_value: float = 1000, n_epochs: int = 10, n_epochs_gen: int = 3, n_epochs_dis: int = 3, gen_interval: int = 3, dis_interval: int = 3, update_ratio: float = 1.0, n_sample_gen: int = 5, n_sample_dis: int = 5, lr_gen: float = 1e-3, lr_dis: float = 1e-3, l2_coef: float = 0.0, batch_size_gen: int = 64, batch_size_dis: int = 64, ) -> None: super().__init__() nn.Module.__init__(self) self.n_epochs = n_epochs self.generator = GeneComGAN( n_nodes=n_nodes, node_emd_init=node_emd_init_gen, n_epochs=n_epochs_gen, gen_interval=gen_interval, update_ratio=update_ratio, n_sample_gen=n_sample_gen, lr_gen=lr_gen, l2_coef=l2_coef, batch_size=batch_size_gen, max_value=max_value, ) self.discriminator = DiscComGAN( n_nodes=n_nodes, node_emd_init=node_emd_init_dis, n_epochs=n_epochs_dis, dis_interval=dis_interval, update_ratio=update_ratio, n_sample_dis=n_sample_dis, lr_dis=lr_dis, l2_coef=l2_coef, batch_size=batch_size_dis, max_value=max_value, ) self.motif_size = None self.total_motifs = None self.neighbor_set = None
[docs] def sampling( self, g_s_args: Tuple[int, int, bool]) -> Tuple[List[Tuple], List[List[int]]]: """sampling Args: g_s_args (Tuple[int,int,bool]): args tuple for each sample thread. Returns: Tuple[List[Tuple], List[List[int]]]: (motifs_new, path_new) """ sampler = CommunityGANSampling( 16, g_s_args, self.motif_size, self.total_motifs, self.generator.get_embedding().numpy(), self.neighbor_set, ) motifs_new, path_new = sampler.run() return motifs_new, path_new
[docs] def forward(self, id2motifs: List[List[Tuple]]) -> None: """forward Args: id2motifs (List[List[Tuple]]): motif lists indexed by node id. """ self.discriminator.fit(self.sampling, id2motifs) self.generator.fit(self.discriminator.get_reward, self.sampling)
[docs] def fit( self, total_motifs: Set[Tuple], id2motifs: List[List[Tuple]], neighbor_set: Dict, motif_size: int = 3, ) -> None: """fit Args: total_motifs (Set[Tuple]): set of motifs. id2motifs (List[List[Tuple]]): motif lists indexed by node id. neighbor_set (Dict): neighbor set Dict indexed by node id. motif_size (int, optional): motif size. Defaults to 3. """ self.motif_size = motif_size self.total_motifs = total_motifs self.neighbor_set = neighbor_set # # for test only # ce = CommunityEval(community_filename, ground_truth_m) # result = ce.eval_community(self.generator.get_embedding()) # print("gen:" + str(result) + "\n") for epoch in range(self.n_epochs): print(f"epoch {epoch}") self.forward(id2motifs)
# # for test only # ce = CommunityEval(community_filename, ground_truth_m) # result = ce.eval_community(self.generator.get_embedding()) # print("gen:" + str(result) + "\n")
[docs] def get_embedding(self) -> Tuple[torch.Tensor, torch.Tensor]: """Get the embeddings (graph or node level). Returns: (torch.Tensor): embedding. """ return (self.generator.get_embedding(), self.discriminator.get_embedding())
[docs] def get_memberships(self) -> np.ndarray: pred, _ = self.get_disjoint_memberships() return pred
[docs] def get_disjoint_memberships(self) -> Tuple[np.ndarray, np.ndarray]: """get disjoint membership Returns: Tuple[np.ndarray, np.ndarray]: generator membership, discriminator membership """ return ( torch.argmax(self.generator.get_embedding(), dim=1).cpu().numpy(), torch.argmax(self.discriminator.get_embedding(), dim=1).cpu().numpy(), )
[docs] def get_overlapping_memberships(self) -> Tuple[np.ndarray, np.ndarray]: """get overlapping membership Returns: Tuple[np.ndarray, np.ndarray]: generator membership, discriminator membership """ # ref to BIGCLAM epsilon = 1e-8 threshold = math.sqrt(-math.log(1 - epsilon)) return ( self.generator.get_embedding() > threshold, self.discriminator.get_embedding() > threshold, )