Source code for egc.model.graph_clustering.disjoint.sgc_kmeans

"""sgc_kmeans
"""
from typing import Callable
from typing import List

import dgl
import torch

from ....utils import sk_clustering
from ...node_embedding import SGC
from ..base import Base


class SGCKmeans(Base):
    """Cluster graph nodes by running k-means on SGC node embeddings.

    An ``SGC`` node-embedding model is built at construction time, trained on
    a graph by :meth:`fit`, and its embeddings are then partitioned by
    ``sk_clustering`` (called with ``name="kmeans"``) in
    :meth:`get_memberships`.

    All constructor arguments are forwarded unchanged to ``SGC``; the
    descriptions below are inferred from the argument names and should be
    confirmed against the ``SGC`` implementation.

    Args:
        in_feats (int): presumably the dimension of the input node features.
        n_epochs (int, optional): presumably the number of embedding training
            epochs. Defaults to 400.
        hidden_units (List, optional): presumably the sizes of the hidden
            linear layers. Defaults to [500].
        lr (float, optional): presumably the learning rate. Defaults to 0.01.
        early_stop (int, optional): presumably the early stopping threshold.
            Defaults to 10.
        inner_act (Callable, optional): presumably the activation applied
            between layers. Defaults to the identity function.
        n_lin_layers (int, optional): presumably the number of linear layers.
            Defaults to 1.
        n_gnn_layers (int, optional): presumably the number of graph
            propagation steps. Defaults to 10.
    """

    def __init__(
        self,
        in_feats: int,
        n_epochs: int = 400,
        hidden_units: List = [500],
        lr: float = 0.01,
        early_stop: int = 10,
        inner_act: Callable = lambda x: x,
        n_lin_layers: int = 1,
        n_gnn_layers: int = 10,
    ) -> None:
        super().__init__()
        self.model = SGC(
            in_feats=in_feats,
            # Hand the model its own copy: the default list object is shared
            # by every call that omits ``hidden_units``, so passing it by
            # reference would let one instance's changes leak into others.
            hidden_units=list(hidden_units),
            n_lin_layers=n_lin_layers,
            n_gnn_layers=n_gnn_layers,
            lr=lr,
            n_epochs=n_epochs,
            inner_act=inner_act,
            early_stop=early_stop,
        )

    def fit(
        self,
        graph: dgl.DGLGraph,
        n_clusters: int,
        device: torch.device = torch.device("cpu"),
    ):
        """Fit for Specific Graph

        Stores ``n_clusters`` for the later clustering step and trains the
        wrapped ``SGC`` model on ``graph``.

        Args:
            graph (dgl.DGLGraph): dgl graph.
            n_clusters (int): cluster num.
            device (torch.device, optional): torch device.
                Defaults to torch.device('cpu').
        """
        # Only the embedding model is trained here; ``n_clusters`` is consumed
        # later by ``get_memberships``.
        self.n_clusters = n_clusters
        self.model.fit(
            graph=graph,
            device=device,
        )

    def get_embedding(self):
        """Return the node embeddings produced by the wrapped ``SGC`` model.

        Returns:
            Whatever ``self.model.get_embedding()`` returns; presumably a
            torch tensor, since ``get_memberships`` calls ``.cpu()`` on it.
        """
        return self.model.get_embedding()

    def get_memberships(self):
        """Cluster the node embeddings with k-means.

        Returns:
            The result of ``sk_clustering`` applied to the CPU copy of the
            embeddings with ``self.n_clusters`` clusters; presumably one
            cluster label per node.

        Raises:
            AttributeError: if :meth:`fit` has not been called yet, so the
                number of clusters is unknown.
        """
        n_clusters = getattr(self, "n_clusters", None)
        if n_clusters is None:
            # Same exception type as the bare attribute lookup would raise,
            # but with a message that tells the caller what to do.
            raise AttributeError(
                "n_clusters is not set; call fit() before get_memberships()."
            )
        return sk_clustering(
            self.get_embedding().cpu(),
            n_clusters,
            name="kmeans",
        )