"""An implementation of `"DFCN"
<https://ojs.aaai.org/index.php/AAAI/article/view/17198>`_
from the AAAI'21 paper "Deep Fusion Clustering Network".
An interdependency learning-based Structure and Attribute
Information Fusion (SAIF) module is proposed to explicitly merge
the representations learned by an autoencoder and a graph autoencoder
for consensus representation learning.
Also, a reliable target distribution generation measure and a triplet
self-supervision strategy, which facilitate cross-modality information
exploitation, are designed for network training.
"""
# pylint: disable=no-self-use,duplicate-code,W0223
import argparse
import dgl
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from torch import nn
from torch.nn import Parameter
from torch.optim import Adam
from torch.utils.data import DataLoader
from ....utils import sparse_mx_to_torch_sparse_tensor
from ....utils.evaluation import evaluation
from ....utils.load_data import AE_LoadDataset
from ....utils.normalization import symmetrically_normalize_adj
from ...node_embedding.saif import SAIF
from ..base import Base


class DFCN(nn.Module, Base):
"""DFCN.
Args:
graph (dgl.DGLGraph): Graph data in dgl
data (torch.Tensor): node's features
label (torch.Tensor): node's label
n_clusters (int): numbers of clusters
n_node (int, optional): number of nodes. Defaults to None.
device (torch.device, optional): device. Defaults to None.
args (argparse.Namespace):all parameters
"""
def __init__(
self,
graph: dgl.DGLGraph,
data: torch.Tensor,
label: torch.Tensor,
n_clusters: int,
n_node: int,
device: torch.device,
args: argparse.Namespace,
):
nn.Module.__init__(self)
Base.__init__(self)
self.lr = args.lr
self.gamma_value = args.gamma_value
self.device = device
self.n_clusters = n_clusters
self.label = label
self.lambda_value = args.lambda_value
self.batch_size = args.batch_size
self.shuffle = args.shuffle
self.v = args.freedom_degree
self.estop_steps = args.early_stop
pca = PCA(n_components=args.n_input)
X_pca = pca.fit_transform(data)
dataset = AE_LoadDataset(X_pca)
self.train_loader = DataLoader(dataset,
batch_size=self.batch_size,
shuffle=self.shuffle)
self.data = torch.Tensor(dataset.x).to(device)
self.graph = graph
adj = graph.adj_external(scipy_fmt="csr")
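        # symmetrize the adjacency matrix and add self-loops before
        # building the graph used by the IGAE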
adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
adj = adj + sp.eye(adj.shape[0])
        # build the un-normalized DGL graph (with self-loops)
self.adj_orig_graph = dgl.from_scipy(adj).to(device)
        # symmetrically normalize the adjacency matrix
adj = symmetrically_normalize_adj(adj)
self.adj = sparse_mx_to_torch_sparse_tensor(adj).to(device)
self.saif = SAIF(
self.adj_orig_graph,
self.data,
self.train_loader,
self.label,
self.adj,
self.n_clusters,
n_node=n_node,
device=self.device,
args=args,
)
self.saif.fit(args.saif_epochs)
        # take the best pre-trained AE and IGAE models from SAIF
self.ae = self.saif.ae
self.gae = self.saif.igae
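        # Fusion weights for the AE / IGAE embeddings. Note that calling
        # ``.to(device)`` on an ``nn.Parameter`` returns a plain tensor, so
        # ``a`` and ``b`` effectively stay fixed at 0.5 (equal weighting).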
self.a = nn.Parameter(nn.init.constant_(torch.zeros(n_node, args.n_z),
0.5),
requires_grad=True).to(device)
self.b = 1 - self.a
self.cluster_layer = nn.Parameter(torch.Tensor(n_clusters, args.n_z),
requires_grad=True)
torch.nn.init.xavier_normal_(self.cluster_layer.data)
self.gamma = Parameter(torch.zeros(1))

    def forward(self):
"""Forward Propagation
Returns:
x_hat (torch.Tensor):Reconstructed attribute matrix generated by AE decoder
z_hat (torch.Tensor):Reconstructed weighted attribute matrix generated by IGAE decoder
adj_hat (torch.Tensor):Reconstructed adjacency matrix generated by IGAE decoder
q (torch.Tensor):Soft assignment distribution of the fused representations
q1 (torch.Tensor):Soft assignment distribution of IGAE
q2 (torch.Tensor):Soft assignment distribution of AE
z_tilde (torch.Tensor):Clustering embedding
"""
z_ae, _, _, _ = self.ae.encoder(self.data)
z_igae, z_igae_adj = self.gae.encoder(self.adj_orig_graph, self.data)
z_i = self.a * z_ae + self.b * z_igae
z_l = torch.spmm(self.adj, z_i)
s = torch.mm(z_l, z_l.t())
s = F.softmax(s, dim=1)
z_g = torch.mm(s, z_l)
z_tilde = self.gamma * z_g + z_l
x_hat = self.ae.decoder(z_tilde)
z_hat, z_hat_adj = self.gae.decoder(self.adj_orig_graph, z_tilde)
adj_hat = z_igae_adj + z_hat_adj
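        # Student's t-kernel soft assignments:
        # q_ij ∝ (1 + ||z_i - mu_j||^2 / v)^(-(v + 1) / 2),
        # computed for the fused embedding (q), the AE embedding (q1)
        # and the IGAE embedding (q2).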
q = 1.0 / (1.0 + torch.sum(
torch.pow(
(z_tilde).unsqueeze(1) - self.cluster_layer, 2), 2) / self.v)
q = q.pow((self.v + 1.0) / 2.0)
q = (q.t() / torch.sum(q, 1)).t()
q1 = 1.0 / (1.0 + torch.sum(
torch.pow(z_ae.unsqueeze(1) - self.cluster_layer, 2), 2) / self.v)
q1 = q1.pow((self.v + 1.0) / 2.0)
q1 = (q1.t() / torch.sum(q1, 1)).t()
q2 = 1.0 / (1.0 + torch.sum(
torch.pow(z_igae.unsqueeze(1) - self.cluster_layer, 2), 2) /
self.v)
q2 = q2.pow((self.v + 1.0) / 2.0)
q2 = (q2.t() / torch.sum(q2, 1)).t()
return x_hat, z_hat, adj_hat, q, q1, q2, z_tilde

    def _target_distribution(self, q):
        """Compute the target distribution from a soft assignment.

        Args:
            q (torch.Tensor): soft assignment distribution

        Returns:
            torch.Tensor: sharpened target distribution
        """
weight = q**2 / q.sum(0)
return (weight.t() / weight.sum(1)).t()

    def get_embedding(self):
"""Get cluster embedding.
Returns:torch.Tensor
"""
return torch.Tensor(self.embedding)

    def get_memberships(self):
"""Get cluster membership.
Returns:numpy.ndarray
"""
kmeans = KMeans(n_clusters=self.n_clusters,
n_init=20).fit(self.embedding)
return kmeans.labels_

    def fit(self, epochs):
"""Fitting a DFCN clustering model.
Args:
epochs (int): number of train epoch
"""
print("------------------Train All Net--------------------")
        # initialize the cluster centers with KMeans on the SAIF embedding
kmeans = KMeans(n_clusters=self.n_clusters, n_init=20)
kmeans.fit_predict(self.saif.get_embedding())
self.cluster_layer.data = torch.Tensor(kmeans.cluster_centers_).to(
self.device)
(
ARI_score,
NMI_score,
AMI_score,
ACC_score,
Micro_F1_score,
Macro_F1_score,
purity,
) = evaluation(self.label, kmeans.labels_)
print(
"Initiation",
f":ARI {ARI_score:.4f}",
f", NMI {NMI_score:.4f}",
f", AMI {AMI_score:.4f}",
f", ACC {ACC_score:.4f}",
f", Micro_F1 {Micro_F1_score:.4f}",
f", Macro_F1 {Macro_F1_score:.4f}",
f", purity {purity:.4f}",
)
best_loss = 1e9
cnt = 0
best_epoch = 0
self.to(self.device)
optimizer = Adam(self.parameters(), lr=self.lr)
for epoch in range(epochs):
self.train()
optimizer.zero_grad()
x_hat, z_hat, adj_hat, q, q1, q2, z_tilde = self.forward()
tmp_q = q.data
p = self._target_distribution(tmp_q)
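            # Joint loss: AE reconstruction, IGAE reconstruction (weighted
            # attribute + adjacency terms balanced by gamma_value), and a
            # triplet self-supervised KL term that pushes the three soft
            # assignments (fused, AE, IGAE) towards the target distribution p.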
loss_ae = F.mse_loss(x_hat, self.data)
loss_w = F.mse_loss(z_hat, torch.spmm(self.adj, self.data))
loss_a = F.mse_loss(adj_hat, self.adj.to_dense())
loss_igae = loss_w + self.gamma_value * loss_a
loss_kl = F.kl_div((q.log() + q1.log() + q2.log()) / 3,
p,
reduction="batchmean")
            # adjusting the loss trade-off parameters can improve the score
loss = loss_ae + loss_igae + self.lambda_value * loss_kl
cur_loss = loss.item()
loss.backward()
optimizer.step()
            # Per-epoch KMeans evaluation on z_tilde is skipped here to keep
            # training fast.
            print(
                f"Epoch_{epoch}",
                f", loss {cur_loss:.6f}",
            )
# early stopping
if cur_loss < best_loss:
cnt = 0
best_epoch = epoch
best_loss = cur_loss
self.embedding = z_tilde.data.cpu().numpy()
else:
cnt += 1
print(f"loss increase counts:{cnt}")
if cnt >= self.estop_steps:
print(f"early stopping,best epoch:{best_epoch}")
break
print("------------------End Train All Net--------------------")