Source code for egc.model.graph_clustering.disjoint.gae_kmeans

"""gae_kmeans
"""
# import torch
import scipy.sparse as sp
import torch

from ....utils import sk_clustering
from ...node_embedding import DGL_GAE
from ..base import Base


class DGL_GAEKmeans(Base):
    """GAE + k-means graph clustering implemented on top of DGL.

    A ``DGL_GAE`` model is fitted on the graph, and cluster memberships are
    obtained by running k-means (via ``sk_clustering``) on its embedding.

    Args:
        epochs (int): number of embedding training epochs.
        n_clusters (int): cluster num.
        fead_dim (int): dim of features.
        n_nodes (int): number of nodes.
        hidden_dim1 (int, optional): hidden units size of gcn_1.
            Defaults to 32.
        dropout (float, optional): Dropout rate (1 - keep probability).
            Defaults to 0.0.
        lr (float, optional): learning rate. Defaults to 0.01.
        early_stop (int, optional): early stopping threshold.
            Defaults to 10.
        activation (str, optional): activation of gcn layer_1.
            Defaults to 'relu'.
    """

    def __init__(
        self,
        epochs: int,
        n_clusters: int,
        fead_dim: int,
        n_nodes: int,
        hidden_dim1: int = 32,
        dropout: float = 0.0,
        lr: float = 0.01,
        early_stop: int = 10,
        activation: str = "relu",
    ) -> None:
        super().__init__()
        # Kept on the instance because get_memberships() needs it for k-means.
        self.n_clusters = n_clusters
        # Arguments are forwarded positionally, in the same order as received.
        self.model = DGL_GAE(
            epochs,
            n_clusters,
            fead_dim,
            n_nodes,
            hidden_dim1,
            dropout,
            lr,
            early_stop,
            activation,
        )

    def fit(
        self,
        adj_csr: sp.csr_matrix,
        features: torch.Tensor,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """Fit the underlying GAE model on a specific graph.

        Args:
            adj_csr (sp.csr_matrix): sparse matrix describing the graph
                (presumably its adjacency matrix, judging by the name --
                confirm against ``DGL_GAE.fit``).
            features (torch.Tensor): node's features.
            device (torch.device, optional): device forwarded to the
                underlying model's ``fit``. Defaults to CPU.
        """
        self.model.fit(adj_csr, features, device=device)

    def get_embedding(self):
        """Return the embedding produced by the underlying GAE model.

        The value is whatever ``DGL_GAE.get_embedding`` returns; it must
        support ``.cpu()`` because ``get_memberships`` calls it
        (presumably a ``torch.Tensor`` -- confirm against ``DGL_GAE``).
        """
        return self.model.get_embedding()

    def get_memberships(self):
        """Cluster the embedding with k-means and return the result.

        The embedding is moved to CPU before being handed to
        ``sk_clustering`` together with ``self.n_clusters``.

        Returns:
            Whatever ``sk_clustering`` returns for ``name="kmeans"``.
        """
        return sk_clustering(
            self.get_embedding().cpu(),
            self.n_clusters,
            name="kmeans",
        )