egc.model.graph_clustering.overlapping package

Submodules

egc.model.graph_clustering.overlapping.communityGAN module

CommunityGAN. Adapted from: https://github.com/SamJia/CommunityGAN

class egc.model.graph_clustering.overlapping.communityGAN.CommunityGAN(n_nodes: int, node_emd_init_gen: Tensor, node_emd_init_dis: Tensor, max_value: float = 1000, n_epochs: int = 10, n_epochs_gen: int = 3, n_epochs_dis: int = 3, gen_interval: int = 3, dis_interval: int = 3, update_ratio: float = 1.0, n_sample_gen: int = 5, n_sample_dis: int = 5, lr_gen: float = 0.001, lr_dis: float = 0.001, l2_coef: float = 0.0, batch_size_gen: int = 64, batch_size_dis: int = 64)

Bases: Base, Module

Parameters:
  • n_nodes (int) – Number of nodes.

  • node_emd_init_gen (torch.Tensor) – Initial node embeddings for the generator.

  • node_emd_init_dis (torch.Tensor) – Initial node embeddings for the discriminator.

  • max_value (float, optional) – Maximum value of the embeddings. Defaults to 1000.

  • n_epochs (int, optional) – Number of training epochs. Defaults to 10.

  • n_epochs_gen (int, optional) – Number of training epochs for the generator. Defaults to 3.

  • n_epochs_dis (int, optional) – Number of training epochs for the discriminator. Defaults to 3.

  • gen_interval (int, optional) – Sampling interval for the generator. Defaults to 3.

  • dis_interval (int, optional) – Sampling interval for the discriminator. Defaults to 3.

  • update_ratio (float, optional) – Update ratio. Defaults to 1.0.

  • n_sample_gen (int, optional) – Number of samples for the generator. Defaults to 5.

  • n_sample_dis (int, optional) – Number of samples for the discriminator. Defaults to 5.

  • lr_gen (float, optional) – Learning rate of the generator. Defaults to 1e-3.

  • lr_dis (float, optional) – Learning rate of the discriminator. Defaults to 1e-3.

  • l2_coef (float, optional) – L2 regularization coefficient of the optimizers. Defaults to 0.0.

  • batch_size_gen (int, optional) – Batch size of the generator. Defaults to 64.

  • batch_size_dis (int, optional) – Batch size of the discriminator. Defaults to 64.
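The epoch and interval parameters describe an alternating GAN training schedule. The sketch below assumes this adaptation keeps the loop structure of the original CommunityGAN/GraphGAN code: a discriminator phase followed by a generator phase in every outer epoch, with fresh samples drawn every `dis_interval` / `gen_interval` inner iterations. `training_schedule` is a hypothetical helper that only logs the steps; it performs no training.

```python
from typing import List, Tuple


def training_schedule(
    n_epochs: int = 10,
    n_epochs_gen: int = 3,
    n_epochs_dis: int = 3,
    gen_interval: int = 3,
    dis_interval: int = 3,
) -> List[Tuple[int, str]]:
    """Log the (epoch, step) sequence of an assumed alternating GAN loop."""
    log: List[Tuple[int, str]] = []
    for epoch in range(n_epochs):
        # Discriminator phase (assumed to run first): resample every dis_interval inner iterations.
        for d_epoch in range(n_epochs_dis):
            if d_epoch % dis_interval == 0:
                log.append((epoch, "sample_dis"))
            log.append((epoch, "train_dis"))
        # Generator phase: resample every gen_interval inner iterations.
        for g_epoch in range(n_epochs_gen):
            if g_epoch % gen_interval == 0:
                log.append((epoch, "sample_gen"))
            log.append((epoch, "train_gen"))
    return log
```

With the defaults, each outer epoch draws one sample set per player and runs three training steps for each, so an interval equal to the number of inner epochs means "resample once per phase".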

sampling(g_s_args: Tuple[int, int, bool]) → Tuple[List[Tuple], List[List[int]]]
Parameters:

g_s_args (Tuple[int,int,bool]) – Argument tuple for each sampling thread.

Returns:

(motifs_new, path_new)

Return type:

Tuple[List[Tuple], List[List[int]]]

forward(id2motifs: List[List[Tuple]]) → None
Parameters:

id2motifs (List[List[Tuple]]) – motif lists indexed by node id.

fit(total_motifs: Set[Tuple], id2motifs: List[List[Tuple]], neighbor_set: Dict, motif_size: int = 3) → None
Parameters:
  • total_motifs (Set[Tuple]) – set of motifs.

  • id2motifs (List[List[Tuple]]) – motif lists indexed by node id.

  • neighbor_set (Dict) – Dict mapping each node id to its set of neighbors.

  • motif_size (int, optional) – motif size. Defaults to 3.
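A minimal sketch of how the inputs to fit() might be built from an undirected edge list. It assumes that a motif is a clique of motif_size nodes stored as a sorted tuple (triangles for the default motif_size=3) and that neighbor_set maps each node id to a set of neighbor ids; build_motif_inputs is a hypothetical helper, not part of egc.

```python
from itertools import combinations
from typing import Dict, List, Set, Tuple

Motif = Tuple[int, ...]


def build_motif_inputs(
    n_nodes: int, edges: List[Tuple[int, int]], motif_size: int = 3
) -> Tuple[Set[Motif], List[List[Motif]], Dict[int, Set[int]]]:
    # Adjacency as a dict of neighbor sets keyed by node id (self-loops ignored).
    neighbor_set: Dict[int, Set[int]] = {i: set() for i in range(n_nodes)}
    for u, v in edges:
        if u != v:
            neighbor_set[u].add(v)
            neighbor_set[v].add(u)

    # Assumption: a motif is a clique of motif_size nodes, kept as a sorted tuple.
    total_motifs: Set[Motif] = set()
    for u in range(n_nodes):
        # Extend only with larger neighbors so each clique is found exactly once.
        larger = sorted(w for w in neighbor_set[u] if w > u)
        for rest in combinations(larger, motif_size - 1):
            # Every pair in `rest` must also be adjacent for a clique.
            if all(b in neighbor_set[a] for a, b in combinations(rest, 2)):
                total_motifs.add((u, *rest))

    # Index every motif under each of its member nodes.
    id2motifs: List[List[Motif]] = [[] for _ in range(n_nodes)]
    for motif in sorted(total_motifs):
        for node in motif:
            id2motifs[node].append(motif)
    return total_motifs, id2motifs, neighbor_set
```

For a triangle 0-1-2 with a pendant edge 2-3, this yields the single motif (0, 1, 2), and node 3 has an empty motif list.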

get_embedding() → Tuple[Tensor, Tensor]

Get the node embeddings.

Returns:

the embeddings, as a pair of tensors.

Return type:

Tuple[torch.Tensor, torch.Tensor]

get_memberships() → ndarray
get_disjoint_memberships() → Tuple[ndarray, ndarray]

Get disjoint community memberships.

Returns:

generator memberships, discriminator memberships

Return type:

Tuple[np.ndarray, np.ndarray]
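Treating each embedding as a nonnegative node-by-community affiliation matrix, a disjoint assignment can be sketched as a per-node argmax, applied separately to the generator and discriminator matrices. This is an assumption about how the two memberships are derived, not the confirmed implementation; disjoint_memberships is a hypothetical name.

```python
from typing import Tuple

import numpy as np


def disjoint_memberships(
    emb_gen: np.ndarray, emb_dis: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Assign each node to the community with its largest affiliation weight.

    Both inputs are assumed to be (n_nodes, n_communities) nonnegative matrices.
    Ties go to the lowest community index (np.argmax semantics).
    """
    return np.argmax(emb_gen, axis=1), np.argmax(emb_dis, axis=1)
```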

get_overlapping_memberships() → Tuple[ndarray, ndarray]

Get overlapping community memberships.

Returns:

generator memberships, discriminator memberships

Return type:

Tuple[np.ndarray, np.ndarray]
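For overlapping memberships, one common rule for nonnegative affiliation matrices is the BigCLAM-style threshold δ = sqrt(-log(1 - ε)), where ε is a background edge probability, often taken as the graph density 2|E| / (n(n - 1)). The sketch below assumes that rule; overlapping_memberships and background_eps are hypothetical names, and the actual implementation may use a different threshold.

```python
import numpy as np


def background_eps(n_nodes: int, n_edges: int) -> float:
    # Graph density, used here as the assumed background edge probability.
    return 2.0 * n_edges / (n_nodes * (n_nodes - 1))


def overlapping_memberships(embedding: np.ndarray, eps: float) -> np.ndarray:
    """Return a 0/1 (n_nodes, n_communities) matrix; a node may join several communities."""
    # BigCLAM-style threshold: node u belongs to community c if F[u, c] >= delta.
    delta = np.sqrt(-np.log(1.0 - eps))
    return (embedding >= delta).astype(np.int64)
```

Applying the same function to the generator and discriminator matrices gives the pair of memberships listed above; a node whose weights all fall below δ ends up in no community.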

Module contents

Overlapping Models