egc.model.graph_clustering.overlapping package
Submodules
egc.model.graph_clustering.overlapping.communityGAN module
CommunityGAN model, adapted from https://github.com/SamJia/CommunityGAN.
- class egc.model.graph_clustering.overlapping.communityGAN.CommunityGAN(n_nodes: int, node_emd_init_gen: Tensor, node_emd_init_dis: Tensor, max_value: float = 1000, n_epochs: int = 10, n_epochs_gen: int = 3, n_epochs_dis: int = 3, gen_interval: int = 3, dis_interval: int = 3, update_ratio: float = 1.0, n_sample_gen: int = 5, n_sample_dis: int = 5, lr_gen: float = 0.001, lr_dis: float = 0.001, l2_coef: float = 0.0, batch_size_gen: int = 64, batch_size_dis: int = 64)
Bases: Base, Module
- Parameters:
n_nodes (int) – number of nodes.
node_emd_init_gen (torch.Tensor) – initial node embedding for the generator.
node_emd_init_dis (torch.Tensor) – initial node embedding for the discriminator.
max_value (float, optional) – maximum value of the embedding. Defaults to 1000.
n_epochs (int, optional) – number of training epochs. Defaults to 10.
n_epochs_gen (int, optional) – number of training epochs for the generator. Defaults to 3.
n_epochs_dis (int, optional) – number of training epochs for the discriminator. Defaults to 3.
gen_interval (int, optional) – interval of the generator. Defaults to 3.
dis_interval (int, optional) – interval of the discriminator. Defaults to 3.
update_ratio (float, optional) – update ratio. Defaults to 1.0.
n_sample_gen (int, optional) – number of samples for the generator. Defaults to 5.
n_sample_dis (int, optional) – number of samples for the discriminator. Defaults to 5.
lr_gen (float, optional) – learning rate of the generator. Defaults to 1e-3.
lr_dis (float, optional) – learning rate of the discriminator. Defaults to 1e-3.
l2_coef (float, optional) – L2 regularization coefficient of the optimizers. Defaults to 0.0.
batch_size_gen (int, optional) – batch size of the generator. Defaults to 64.
batch_size_dis (int, optional) – batch size of the discriminator. Defaults to 64.
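The parameter list does not spell out how the epoch and interval settings interact. The sketch below assumes a GraphGAN-style alternating schedule (an assumption, not confirmed by this reference): each of the n_epochs outer epochs runs n_epochs_dis discriminator epochs followed by n_epochs_gen generator epochs, and fresh training samples are drawn every dis_interval or gen_interval inner epochs. The function training_schedule is hypothetical and only records the order of steps; it does not train anything.

```python
from typing import List, Tuple


def training_schedule(
    n_epochs: int = 10,
    n_epochs_gen: int = 3,
    n_epochs_dis: int = 3,
    gen_interval: int = 3,
    dis_interval: int = 3,
) -> List[Tuple[int, str, int]]:
    """Hypothetical sketch: list (outer_epoch, step, inner_epoch) events of an
    ASSUMED alternating GAN schedule. step is one of 'sample_dis',
    'train_dis', 'sample_gen', 'train_gen'."""
    events: List[Tuple[int, str, int]] = []
    for epoch in range(n_epochs):
        # Discriminator phase: resample its training data every dis_interval inner epochs.
        for d_epoch in range(n_epochs_dis):
            if d_epoch % dis_interval == 0:
                events.append((epoch, "sample_dis", d_epoch))
            events.append((epoch, "train_dis", d_epoch))
        # Generator phase: resample its training data every gen_interval inner epochs.
        for g_epoch in range(n_epochs_gen):
            if g_epoch % gen_interval == 0:
                events.append((epoch, "sample_gen", g_epoch))
            events.append((epoch, "train_gen", g_epoch))
    return events


if __name__ == "__main__":
    for event in training_schedule(n_epochs=1):
        print(event)
```

With the defaults, this assumed schedule yields 80 events: 10 discriminator samplings, 30 discriminator epochs, 10 generator samplings and 30 generator epochs. Setting dis_interval=1 would resample before every discriminator epoch.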
- sampling(g_s_args: Tuple[int, int, bool]) → Tuple[List[Tuple], List[List[int]]]
- Parameters:
g_s_args (Tuple[int,int,bool]) – argument tuple passed to each sampling thread.
- Returns:
(motifs_new, path_new)
- Return type:
Tuple[List[Tuple], List[List[int]]]
- forward(id2motifs: List[List[Tuple]]) → None
- Parameters:
id2motifs (List[List[Tuple]]) – motif lists indexed by node id.
- fit(total_motifs: Set[Tuple], id2motifs: List[List[Tuple]], neighbor_set: Dict, motif_size: int = 3) → None
- Parameters:
total_motifs (Set[Tuple]) – set of motifs.
id2motifs (List[List[Tuple]]) – motif lists indexed by node id.
neighbor_set (Dict) – dict of neighbor sets, keyed by node id.
motif_size (int, optional) – motif size. Defaults to 3.
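The inputs to fit can be derived from a plain edge list. The sketch below assumes that motifs are cliques of motif_size nodes stored as sorted tuples of node ids and that neighbor_set maps each node id to a Python set; the exact formats the egc implementation expects are not confirmed here, and build_fit_inputs is a hypothetical helper.

```python
from itertools import combinations
from typing import Dict, List, Set, Tuple


def build_fit_inputs(
    n_nodes: int, edges: List[Tuple[int, int]], motif_size: int = 3
) -> Tuple[Set[Tuple], List[List[Tuple]], Dict[int, Set[int]]]:
    """Hypothetical helper: build (total_motifs, id2motifs, neighbor_set),
    ASSUMING motifs are cliques of size motif_size stored as sorted tuples."""
    # neighbor_set: node id -> set of adjacent node ids (undirected, self-loops dropped)
    neighbor_set: Dict[int, Set[int]] = {u: set() for u in range(n_nodes)}
    for u, v in edges:
        if u != v:
            neighbor_set[u].add(v)
            neighbor_set[v].add(u)

    # total_motifs: every clique of motif_size nodes, found once from its smallest node
    total_motifs: Set[Tuple] = set()
    for u in range(n_nodes):
        higher = sorted(w for w in neighbor_set[u] if w > u)
        for rest in combinations(higher, motif_size - 1):
            # u is adjacent to all of rest; check that rest is pairwise adjacent too
            if all(b in neighbor_set[a] for a, b in combinations(rest, 2)):
                total_motifs.add((u,) + rest)

    # id2motifs: for each node id, the motifs that contain that node
    id2motifs: List[List[Tuple]] = [[] for _ in range(n_nodes)]
    for motif in sorted(total_motifs):
        for node in motif:
            id2motifs[node].append(motif)
    return total_motifs, id2motifs, neighbor_set


if __name__ == "__main__":
    tm, i2m, ns = build_fit_inputs(4, [(0, 1), (1, 2), (0, 2), (2, 3), (0, 3)])
    print(sorted(tm))
    print(i2m)
```

The three results would then be passed as model.fit(total_motifs, id2motifs, neighbor_set, motif_size=3). Enumerating cliques through combinations of neighbors is exponential in node degree, so this sketch only suits small or sparse graphs.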
- get_embedding() → Tuple[Tensor, Tensor]
Get the embeddings (graph or node level).
- Returns:
embeddings.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
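This reference does not state which of the two returned tensors holds node memberships or how overlapping communities are read from it. Assuming one of them is a nonnegative node-by-community membership matrix F, as in AGM/BigCLAM-style models, a common rule assigns node u to community c when F[u][c] ≥ δ = sqrt(-log(1 - ε)), with background edge probability ε = 2|E| / (|V|(|V| - 1)). The helper memberships_to_communities is hypothetical; a tensor can first be converted to nested lists with Tensor.tolist().

```python
import math
from typing import List


def memberships_to_communities(emb: List[List[float]], n_edges: int) -> List[List[int]]:
    """Hypothetical helper: threshold an ASSUMED nonnegative membership matrix
    (rows = nodes, columns = communities) into overlapping communities."""
    n_nodes = len(emb)
    if n_nodes < 2:
        raise ValueError("need at least two nodes")
    n_comms = len(emb[0])
    # Background edge probability eps = 2|E| / (|V| (|V| - 1)), kept below 1 so log is defined
    eps = min(2.0 * n_edges / (n_nodes * (n_nodes - 1)), 1.0 - 1e-12)
    delta = math.sqrt(-math.log(1.0 - eps))
    communities: List[List[int]] = [[] for _ in range(n_comms)]
    for u, row in enumerate(emb):
        for c, strength in enumerate(row):
            # A node may pass the threshold in several columns, giving overlaps
            if strength >= delta:
                communities[c].append(u)
    return communities


if __name__ == "__main__":
    emb = [[2.0, 1.5], [2.0, 0.0], [1.8, 1.9], [0.1, 2.2]]
    print(memberships_to_communities(emb, n_edges=5))
```

Nodes whose strengths exceed δ in several columns are placed in several communities, which is what makes the result overlapping.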
Module contents
Overlapping Models