"""
A structure and attribute information fusion (SAIF) module
"""
import argparse
import dgl
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter
from torch.optim import Adam
from torch.utils.data import DataLoader
from ...model.node_embedding.ae import AE
from ...model.node_embedding.igae import IGAE
# from sklearn.cluster import KMeans
# from utils.evaluation import evaluation
# ---------------------------------- SAIF -----------------------------------------
class SAIF(nn.Module):
    """A structure and attribute information fusion (SAIF) module.

    Fuses the latent representations of a pre-trained autoencoder (AE) and a
    graph autoencoder (IGAE) with a learnable per-node weight ``a`` (and its
    complement ``b = 1 - a``), then refines the fused embedding through a
    self-correlation propagation weighted by the learnable scalar ``gamma``.
    Both sub-models are pre-trained inside ``__init__``.

    Args:
        adj_orig_graph (dgl.DGLGraph): graph data in dgl format
        data (torch.Tensor): node feature matrix
        train_loader (DataLoader): DataLoader used for AE pre-training
        label (torch.Tensor): node labels, forwarded to AE pre-training
        adj (torch.Tensor): normalized adjacency matrix.  NOTE(review): the
            original annotation said ``sp.csr.csr_matrix``, but ``forward``
            calls ``torch.spmm`` on it and ``fit`` calls ``.to_dense()``, so
            a torch sparse tensor is required -- confirm against callers.
        n_clusters (int): number of clusters
        n_node (int): number of nodes
        device (torch.device): computation device
        args (argparse.Namespace): all hyper-parameters (lr, gamma_value,
            early_stop, n_z, AE/IGAE layer sizes, epochs, ...)
    """

    def __init__(
        self,
        adj_orig_graph: dgl.DGLGraph,
        data: torch.Tensor,
        train_loader: DataLoader,
        label: torch.Tensor,
        adj: torch.Tensor,
        n_clusters: int,
        n_node: int,
        device: torch.device,
        args: argparse.Namespace,
    ):
        super().__init__()
        self.adj_orig_graph = adj_orig_graph
        self.data = data
        self.adj = adj
        self.train_loader = train_loader
        self.lr = args.lr
        self.label = label
        self.device = device
        self.gamma_value = args.gamma_value
        self.n_clusters = n_clusters
        self.estop_steps = args.early_stop
        # best-loss clustering embedding captured during fit(); numpy array or None
        self.embedding = None
        self.ae = AE(
            n_input=args.n_input,
            n_clusters=self.n_clusters,
            hidden1=args.ae_n_enc_1,
            hidden2=args.ae_n_enc_2,
            hidden3=args.ae_n_enc_3,
            hidden4=args.ae_n_dec_1,
            hidden5=args.ae_n_dec_2,
            hidden6=args.ae_n_dec_3,
            lr=args.lr,
            epochs=args.ae_epochs,
            n_z=args.n_z,
            activation="leakyrelu",
            early_stop=args.early_stop,
            if_eva=False,
            if_early_stop=False,
        )
        self.igae = IGAE(args, self.device)
        # BUG FIX: the original wrapped the Parameter in ``.to(device)``.
        # On an actual device move ``.to()`` returns a plain non-leaf Tensor,
        # so ``self.a`` was never registered in ``self.parameters()`` and was
        # silently excluded from optimization.  Constructing the tensor
        # directly on ``device`` keeps it a proper leaf Parameter.
        self.a = nn.Parameter(
            torch.full((n_node, args.n_z), 0.5, device=device),
            requires_grad=True,
        )
        # ``b`` is derived from ``a`` once at construction time (as in the
        # original DFCN-style code); the subtraction node is reused on every
        # forward pass.
        self.b = 1 - self.a
        # learnable mixing weight for the self-correlation term, init 0
        self.gamma = Parameter(torch.zeros(1))
        # pre-train the AE and IGAE independently
        self.ae.fit(self.data, self.train_loader, self.label)
        self.igae.fit(
            self.adj_orig_graph,
            self.data,
            self.adj,
        )

    def forward(self):
        """Forward propagation.

        Returns:
            x_hat (torch.Tensor): reconstructed attribute matrix from the AE decoder
            z_hat (torch.Tensor): reconstructed weighted attribute matrix from the IGAE decoder
            z_tilde (torch.Tensor): clustering embedding
            adj_hat (torch.Tensor): reconstructed adjacency matrix from the IGAE decoder
        """
        z_ae, _, _, _ = self.ae.encoder(self.data)
        z_igae, z_igae_adj = self.igae.encoder(self.adj_orig_graph, self.data)
        # element-wise fusion of AE and IGAE embeddings
        z_i = self.a * z_ae + self.b * z_igae
        # propagate the fused embedding over the graph
        z_l = torch.spmm(self.adj, z_i)
        # self-correlation matrix, row-normalized via softmax
        s = F.softmax(torch.mm(z_l, z_l.t()), dim=1)
        z_g = torch.mm(s, z_l)
        # blend the globally-smoothed embedding back in with weight gamma
        z_tilde = self.gamma * z_g + z_l
        x_hat = self.ae.decoder(z_tilde)
        z_hat, z_hat_adj = self.igae.decoder(self.adj_orig_graph, z_tilde)
        adj_hat = z_igae_adj + z_hat_adj
        return x_hat, z_hat, z_tilde, adj_hat

    def get_embedding(self):
        """Get the cluster embedding saved at the best-loss epoch.

        Returns:
            numpy.ndarray or None: ``None`` until ``fit`` has run.
        """
        return self.embedding

    def fit(self, epochs):
        """Fit the SAIF clustering model.

        Trains with Adam on the sum of the AE reconstruction loss and the
        IGAE loss (0.02 * feature loss + gamma_value * adjacency loss),
        early-stopping after ``self.estop_steps`` non-improving epochs.

        Args:
            epochs (int): maximum number of training epochs
        """
        print("------------------Pretrain SAIF--------------------")
        best_loss = 1e9
        cnt = 0  # consecutive epochs without loss improvement
        best_epoch = 0
        self.to(self.device)
        optimizer = Adam(self.parameters(), lr=self.lr)
        for epoch in range(epochs):
            self.train()
            optimizer.zero_grad()
            x_hat, z_hat, z_tilde, adj_hat = self.forward()
            loss_ae = F.mse_loss(x_hat, self.data)
            # IGAE reconstructs the graph-weighted features and the adjacency
            loss_w = F.mse_loss(z_hat, torch.spmm(self.adj, self.data))
            loss_a = F.mse_loss(adj_hat, self.adj.to_dense())
            loss_igae = 0.02 * loss_w + self.gamma_value * loss_a
            loss = loss_ae + loss_igae
            loss.backward()
            cur_loss = loss.item()
            optimizer.step()
            print(
                f"Epoch_{epoch}",
                f", Loss {cur_loss}",
            )
            # early stopping on the training loss
            if cur_loss < best_loss:
                cnt = 0
                best_epoch = epoch
                best_loss = cur_loss
                self.embedding = z_tilde.detach().cpu().numpy()
            else:
                cnt += 1
                print(f"loss increase count:{cnt}")
                if cnt >= self.estop_steps:
                    print(f"early stopping,best epoch:{best_epoch}")
                    break
        print("------------------End Pretrain SAIF--------------------")