Source code for egc.model.graph_clustering.disjoint.dfcn

"""An implementation of `"DFCN"
<https://ojs.aaai.org/index.php/AAAI/article/view/17198>`_
from the AAAI'21 paper "Deep Fusion Clustering Network".

An interdependency learning-based Structure and Attribute
Information Fusion (SAIF) module is proposed to explicitly merge
the representations learned by an autoencoder and a graph autoencoder
for consensus representation learning.

Also, a reliable target distribution generation measure and a triplet
self-supervision strategy, which facilitate cross-modality information
exploitation, are designed for network training.
"""
# pylint: disable=no-self-use,duplicate-code,W0223
import argparse

import dgl
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from torch import nn
from torch.nn import Parameter
from torch.optim import Adam
from torch.utils.data import DataLoader

from ....utils import sparse_mx_to_torch_sparse_tensor
from ....utils.evaluation import evaluation
from ....utils.load_data import AE_LoadDataset
from ....utils.normalization import symmetrically_normalize_adj
from ...node_embedding.saif import SAIF
from ..base import Base


class DFCN(nn.Module, Base):
    """DFCN.

    Args:
        graph (dgl.DGLGraph): graph in DGL format
        data (torch.Tensor): node feature matrix
        label (torch.Tensor): node labels
        n_clusters (int): number of clusters
        n_node (int): number of nodes
        device (torch.device): device to run on
        args (argparse.Namespace): all other hyperparameters
    """

    def __init__(
        self,
        graph: dgl.DGLGraph,
        data: torch.Tensor,
        label: torch.Tensor,
        n_clusters: int,
        n_node: int,
        device: torch.device,
        args: argparse.Namespace,
    ):
        nn.Module.__init__(self)
        Base.__init__(self)
        self.lr = args.lr
        self.gamma_value = args.gamma_value
        self.device = device
        self.n_clusters = n_clusters
        self.label = label
        self.lambda_value = args.lambda_value
        self.batch_size = args.batch_size
        self.shuffle = args.shuffle
        self.v = args.freedom_degree
        self.estop_steps = args.early_stop

        # reduce the raw features to args.n_input dimensions before training
        pca = PCA(n_components=args.n_input)
        X_pca = pca.fit_transform(data)
        dataset = AE_LoadDataset(X_pca)
        self.train_loader = DataLoader(dataset,
                                       batch_size=self.batch_size,
                                       shuffle=self.shuffle)
        self.data = torch.Tensor(dataset.x).to(device)

        self.graph = graph
        u1, v1 = self.graph.edges()
        adj = graph.adj_external(scipy_fmt="csr")
        # symmetrize the adjacency matrix and add self-loops
        adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
        adj = adj + sp.eye(adj.shape[0])
        # DGL graph built from the unnormalized adjacency matrix
        self.adj_orig_graph = dgl.from_scipy(adj).to(device)
        # symmetrically normalized adjacency matrix
        adj = symmetrically_normalize_adj(adj)
        self.adj = sparse_mx_to_torch_sparse_tensor(adj).to(device)

        # pre-train the AE and IGAE through the SAIF module
        self.saif = SAIF(
            self.adj_orig_graph,
            self.data,
            self.train_loader,
            self.label,
            self.adj,
            self.n_clusters,
            n_node=n_node,
            device=self.device,
            args=args,
        )
        self.saif.fit(args.saif_epochs)
        # take the best pre-trained AE and IGAE models
        self.ae = self.saif.ae
        self.gae = self.saif.igae

        # learnable fusion coefficients between the AE and IGAE embeddings
        self.a = nn.Parameter(nn.init.constant_(torch.zeros(n_node, args.n_z),
                                                0.5),
                              requires_grad=True).to(device)
        self.b = 1 - self.a
        self.cluster_layer = nn.Parameter(torch.Tensor(n_clusters, args.n_z),
                                          requires_grad=True)
        torch.nn.init.xavier_normal_(self.cluster_layer.data)
        self.gamma = Parameter(torch.zeros(1))

    def forward(self):
        """Forward propagation.

        Returns:
            x_hat (torch.Tensor): reconstructed attribute matrix from the AE decoder
            z_hat (torch.Tensor): reconstructed weighted attribute matrix from the IGAE decoder
            adj_hat (torch.Tensor): reconstructed adjacency matrix from the IGAE decoder
            q (torch.Tensor): soft assignment distribution of the fused representation
            q1 (torch.Tensor): soft assignment distribution of the AE embedding
            q2 (torch.Tensor): soft assignment distribution of the IGAE embedding
            z_tilde (torch.Tensor): clustering embedding
        """
        z_ae, _, _, _ = self.ae.encoder(self.data)
        z_igae, z_igae_adj = self.gae.encoder(self.adj_orig_graph, self.data)
        # local fusion: weighted sum of the AE and IGAE embeddings,
        # propagated once over the normalized adjacency matrix
        z_i = self.a * z_ae + self.b * z_igae
        z_l = torch.spmm(self.adj, z_i)
        # global fusion: self-correlation smoothing with a learnable skip weight
        s = torch.mm(z_l, z_l.t())
        s = F.softmax(s, dim=1)
        z_g = torch.mm(s, z_l)
        z_tilde = self.gamma * z_g + z_l
        x_hat = self.ae.decoder(z_tilde)
        z_hat, z_hat_adj = self.gae.decoder(self.adj_orig_graph, z_tilde)
        adj_hat = z_igae_adj + z_hat_adj

        # Student's t soft assignments of each embedding to the cluster centers
        q = 1.0 / (1.0 + torch.sum(
            torch.pow(z_tilde.unsqueeze(1) - self.cluster_layer, 2), 2) /
                   self.v)
        q = q.pow((self.v + 1.0) / 2.0)
        q = (q.t() / torch.sum(q, 1)).t()

        q1 = 1.0 / (1.0 + torch.sum(
            torch.pow(z_ae.unsqueeze(1) - self.cluster_layer, 2), 2) / self.v)
        q1 = q1.pow((self.v + 1.0) / 2.0)
        q1 = (q1.t() / torch.sum(q1, 1)).t()

        q2 = 1.0 / (1.0 + torch.sum(
            torch.pow(z_igae.unsqueeze(1) - self.cluster_layer, 2), 2) /
                    self.v)
        q2 = q2.pow((self.v + 1.0) / 2.0)
        q2 = (q2.t() / torch.sum(q2, 1)).t()

        return x_hat, z_hat, adj_hat, q, q1, q2, z_tilde

    def _target_distribution(self, q):
        """Calculate the target distribution P from the soft assignment Q.

        Args:
            q (torch.Tensor): soft assignment distribution

        Returns:
            torch.Tensor: target distribution
        """
        weight = q**2 / q.sum(0)
        return (weight.t() / weight.sum(1)).t()
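
    # For reference, _target_distribution implements
    #
    #     p_ij = (q_ij^2 / sum_i q_ij) / sum_k (q_ik^2 / sum_i q_ik),
    #
    # i.e. the soft assignments are squared and renormalized so that
    # high-confidence assignments are emphasized. fit() feeds the element-wise
    # mean of log q, log q1 and log q2 into F.kl_div against this target P,
    # which is the triplet self-supervision signal.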

    def get_embedding(self):
        """Get the clustering embedding.

        Returns:
            torch.Tensor: clustering embedding
        """
        return torch.Tensor(self.embedding)

    def get_memberships(self):
        """Get cluster memberships.

        Returns:
            numpy.ndarray: cluster labels predicted by k-means on the embedding
        """
        kmeans = KMeans(n_clusters=self.n_clusters,
                        n_init=20).fit(self.embedding)
        return kmeans.labels_

    def fit(self, epochs):
        """Fit the DFCN clustering model.

        Args:
            epochs (int): number of training epochs
        """
        print("------------------Train All Net--------------------")
        # initialize the cluster centers with k-means on the pre-trained
        # SAIF embedding
        # with torch.no_grad():
        #     x_hat, z_hat, z_tilde, adj_hat = self.saif()
        kmeans = KMeans(n_clusters=self.n_clusters, n_init=20)
        kmeans.fit_predict(self.saif.get_embedding())
        self.cluster_layer.data = torch.Tensor(kmeans.cluster_centers_).to(
            self.device)
        (
            ARI_score,
            NMI_score,
            AMI_score,
            ACC_score,
            Micro_F1_score,
            Macro_F1_score,
            purity,
        ) = evaluation(self.label, kmeans.labels_)
        print(
            "Initiation",
            f":ARI {ARI_score:.4f}",
            f", NMI {NMI_score:.4f}",
            f", AMI {AMI_score:.4f}",
            f", ACC {ACC_score:.4f}",
            f", Micro_F1 {Micro_F1_score:.4f}",
            f", Macro_F1 {Macro_F1_score:.4f}",
            f", purity {purity:.4f}",
        )

        best_loss = 1e9
        cnt = 0
        best_epoch = 0
        self.to(self.device)
        optimizer = Adam(self.parameters(), lr=self.lr)
        for epoch in range(epochs):
            self.train()
            optimizer.zero_grad()
            x_hat, z_hat, adj_hat, q, q1, q2, z_tilde = self.forward()
            tmp_q = q.data
            p = self._target_distribution(tmp_q)

            # reconstruction losses of the AE and IGAE
            loss_ae = F.mse_loss(x_hat, self.data)
            loss_w = F.mse_loss(z_hat, torch.spmm(self.adj, self.data))
            loss_a = F.mse_loss(adj_hat, self.adj.to_dense())
            loss_igae = loss_w + self.gamma_value * loss_a
            # triplet self-supervised clustering loss
            loss_kl = F.kl_div((q.log() + q1.log() + q2.log()) / 3,
                               p,
                               reduction="batchmean")
            # adjusting the trade-off parameters of the loss can improve the score
            loss = loss_ae + loss_igae + self.lambda_value * loss_kl
            cur_loss = loss.item()
            loss.backward()
            optimizer.step()

            # optional per-epoch evaluation:
            # kmeans = KMeans(n_clusters=self.n_clusters,
            #                 n_init=20).fit(z_tilde.data.cpu().numpy())
            # (
            #     ARI_score,
            #     NMI_score,
            #     ACC_score,
            #     Micro_F1_score,
            #     Macro_F1_score,
            # ) = evaluation(self.label, kmeans.labels_)
            print(
                f"Epoch_{epoch}",
                # f":ARI {ARI_score:.4f}",
                # f", NMI {NMI_score:.4f}",
                # f", ACC {ACC_score:.4f}",
                # f", Micro_F1 {Micro_F1_score:.4f}",
                # f", Macro_F1 {Macro_F1_score:.4f}",
                f", loss {cur_loss}",
            )

            # early stopping on the training loss
            if cur_loss < best_loss:
                cnt = 0
                best_epoch = epoch
                best_loss = cur_loss
                self.embedding = z_tilde.data.cpu().numpy()
                # self.memberships = kmeans.labels_
            else:
                cnt += 1
                print(f"loss increase counts: {cnt}")
                if cnt >= self.estop_steps:
                    print(f"early stopping, best epoch: {best_epoch}")
                    break
        print("------------------End Train All Net--------------------")
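

# A minimal, self-contained sketch (not part of the model) of the two
# quantities that drive the clustering loss above: the Student's t soft
# assignment computed in forward() and the target distribution computed in
# _target_distribution(). The sizes and the degree of freedom below are
# illustrative assumptions, not values used by the library.
if __name__ == "__main__":
    n, k, d, v = 6, 3, 4, 1.0
    z = torch.randn(n, d)        # toy embeddings
    centers = torch.randn(k, d)  # toy cluster centers

    # Student's t kernel: q_ij is proportional to
    # (1 + ||z_i - mu_j||^2 / v) ** (-(v + 1) / 2), normalized per row
    q = 1.0 / (1.0 + torch.sum(torch.pow(z.unsqueeze(1) - centers, 2), 2) / v)
    q = q.pow((v + 1.0) / 2.0)
    q = (q.t() / torch.sum(q, 1)).t()

    # target distribution: q_ij^2 / sum_i q_ij, renormalized per row
    weight = q**2 / q.sum(0)
    p = (weight.t() / weight.sum(1)).t()

    # both are row-stochastic; P is a sharpened version of Q
    print(q.sum(1), p.sum(1))
    print(F.kl_div(q.log(), p, reduction="batchmean"))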