# Source code for egc.model.graph_clustering.disjoint.dgi_kmeans

"""DGI + Kmeans Graph Clustering
"""
from typing import List, Union

import dgl
import torch

from ....utils import sk_clustering
from ...node_embedding import DGIEmbed
from ..base import Base


class DGIKmeans(Base):
    """DGI + KMeans graph clustering.

    Trains a DGI (Deep Graph Infomax) node-embedding model, then clusters
    the learned node embeddings with KMeans.

    Args:
        in_feats (int): input feature dimension.
        out_feats_list (List[int]): list of hidden units dimensions.
        n_epochs (int, optional): number of embedding training epochs.
            Defaults to 10000.
        early_stopping_epoch (int, optional): early stopping threshold.
            Defaults to 20.
        batch_size (int, optional): batch size. Defaults to 1024.
        neighbor_sampler_fanouts (Union[List[int], int], optional): list of
            neighbors to sample for each GNN layer, with the i-th element
            being the fanout for the i-th GNN layer. Defaults to -1.

            - If only a single integer is provided, DGL assumes that every
              layer will have the same fanout.
            - If -1 is provided on one layer, then all inbound edges will
              be included.
        lr (float, optional): learning rate. Defaults to 0.001.
        l2_coef (float, optional): weight decay. Defaults to 0.0.
        activation (str): activation of gcn layer. Defaults to prelu.
        model_filename (str, optional): path to save best model parameters.
            Defaults to `dgi`.
    """

    def __init__(
        self,
        in_feats: int,
        out_feats_list: List[int],
        n_epochs: int = 10000,
        early_stopping_epoch: int = 20,
        batch_size: int = 1024,
        # BUG FIX: the original annotation ``List[int] or int`` is evaluated
        # at runtime and short-circuits to just ``List[int]``; ``Union``
        # expresses the intended "list of fanouts OR a single fanout" type.
        neighbor_sampler_fanouts: Union[List[int], int] = -1,
        lr: float = 0.001,
        l2_coef: float = 0.0,
        activation: str = "prelu",
        model_filename: str = "dgi",
    ) -> None:
        super().__init__()
        # Number of clusters is unknown until ``fit`` is called; it is
        # stored there and consumed later by ``get_memberships``.
        self.n_clusters = None
        # All embedding-training hyperparameters are delegated to DGIEmbed.
        self.model = DGIEmbed(
            in_feats=in_feats,
            out_feats_list=out_feats_list,
            n_epochs=n_epochs,
            early_stopping_epoch=early_stopping_epoch,
            batch_size=batch_size,
            neighbor_sampler_fanouts=neighbor_sampler_fanouts,
            lr=lr,
            l2_coef=l2_coef,
            activation=activation,
            model_filename=model_filename,
        )

    def fit(
        self,
        graph: dgl.DGLGraph,
        n_clusters: int,
        device: torch.device = torch.device("cpu"),
    ):
        """Fit for Specific Graph.

        Trains the underlying DGI embedding model and records the number of
        clusters for the later KMeans step.

        Args:
            graph (dgl.DGLGraph): dgl graph.
            n_clusters (int): cluster num.
            device (torch.device, optional): torch device.
                Defaults to torch.device('cpu').
        """
        self.n_clusters = n_clusters
        self.model.fit(
            graph=graph,
            device=device,
        )

    def get_embedding(
        self,
        graph: dgl.DGLGraph,
        device: torch.device = torch.device("cpu"),
        model_filename: str = None,
    ) -> torch.Tensor:
        """Get the embeddings.

        Args:
            graph (dgl.DGLGraph): dgl graph.
            device (torch.device, optional): torch device.
                Defaults to torch.device('cpu').
            model_filename (str, optional): Model file to load.
                Defaults to None.

        Returns:
            torch.Tensor: Embeddings.
        """
        return self.model.get_embedding(graph, device, model_filename)

    def get_memberships(
        self,
        graph: dgl.DGLGraph,
        device: torch.device = torch.device("cpu"),
        model_filename: str = None,
    ) -> torch.Tensor:
        """Get the memberships.

        Runs KMeans (via ``sk_clustering``) on the DGI embeddings.

        Args:
            graph (dgl.DGLGraph): dgl graph.
            device (torch.device, optional): torch device.
                Defaults to torch.device('cpu').
            model_filename (str, optional): Model file to load.
                Defaults to None.

        Returns:
            torch.Tensor: cluster membership label for each node.
        """
        # Embeddings may carry a leading singleton batch dim; squeeze it
        # before handing a 2-D (nodes x features) matrix to KMeans.
        return sk_clustering(
            torch.squeeze(
                self.get_embedding(
                    graph,
                    device,
                    model_filename,
                ),
                0,
            ).cpu(),
            self.n_clusters,
            name="kmeans",
        )
# for test only
# if __name__ == '__main__':
#     from utils import load_data
#     from utils.evaluation import evaluation
#     from utils import set_device
#     from utils import set_seed
#     import scipy.sparse as sp
#     import time
#     set_seed(4096)
#     device = set_device('0')
#     graph, label, n_clusters = load_data(
#         dataset_name='Cora',
#         directory='./data',
#     )
#     features = graph.ndata["feat"]
#     adj_csr = graph.adj_external(scipy_fmt='csr')
#     edges = graph.edges()
#     features_lil = sp.lil_matrix(features)
#     start_time = time.time()
#     model = DGIKmeans(
#         in_feats=features.shape[1],
#         out_feats_list=[512],  # fixed: the class has no `hidden_units` parameter
#         model_filename='dgi_test',
#         n_epochs=1000,
#     )
#     model.fit(graph=graph, n_clusters=n_clusters, device=device)
#     res = model.get_memberships(graph=graph, device=device)  # fixed: graph arg is required
#     elapsed_time = time.time() - start_time
#     (
#         ARI_score,
#         NMI_score,
#         ACC_score,
#         Micro_F1_score,
#         Macro_F1_score,
#     ) = evaluation(label, res)
#     print("\n"
#           f"Elapsed Time:{elapsed_time:.2f}s\n"
#           f"ARI:{ARI_score}\n"
#           f"NMI:{NMI_score}\n"
#           f"ACC:{ACC_score}\n"
#           f"Micro F1:{Micro_F1_score}\n"
#           f"Macro F1:{Macro_F1_score}\n")