Source code for egc.model.node_embedding.saif

"""
A structure and attribute information fusion (SAIF) module
"""
import argparse

import dgl
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter
from torch.optim import Adam
from torch.utils.data import DataLoader

from ...model.node_embedding.ae import AE
from ...model.node_embedding.igae import IGAE

# from sklearn.cluster import KMeans
# from utils.evaluation import evaluation


# ---------------------------------- SAIF -----------------------------------------
class SAIF(nn.Module):
    """A structure and attribute information fusion (SAIF) module.

    Args:
        adj_orig_graph (dgl.DGLGraph): graph data in dgl format
        data (torch.Tensor): node features
        train_loader (DataLoader): DataLoader used for AE pretraining
        label (torch.Tensor): node labels
        adj (torch.Tensor): sparse adjacency matrix (torch sparse tensor)
        n_clusters (int): number of clusters
        n_node (int): number of nodes
        device (torch.device): device to run on
        args (argparse.Namespace): all hyperparameters
    """

    def __init__(
        self,
        adj_orig_graph: dgl.DGLGraph,
        data: torch.Tensor,
        train_loader: DataLoader,
        label: torch.Tensor,
        adj: torch.Tensor,
        n_clusters: int,
        n_node: int,
        device: torch.device,
        args: argparse.Namespace,
    ):
        super().__init__()
        self.adj_orig_graph = adj_orig_graph
        self.data = data
        self.adj = adj
        self.train_loader = train_loader
        self.lr = args.lr
        self.label = label
        self.device = device
        self.gamma_value = args.gamma_value
        self.n_clusters = n_clusters
        self.estop_steps = args.early_stop
        self.embedding = None

        self.ae = AE(
            n_input=args.n_input,
            n_clusters=self.n_clusters,
            hidden1=args.ae_n_enc_1,
            hidden2=args.ae_n_enc_2,
            hidden3=args.ae_n_enc_3,
            hidden4=args.ae_n_dec_1,
            hidden5=args.ae_n_dec_2,
            hidden6=args.ae_n_dec_3,
            lr=args.lr,
            epochs=args.ae_epochs,
            # device=self.device,
            n_z=args.n_z,
            activation="leakyrelu",
            early_stop=args.early_stop,
            if_eva=False,
            if_early_stop=False,
        )
        self.igae = IGAE(args, self.device)

        # learnable element-wise weight that fuses the AE and IGAE embeddings;
        # its complement (1 - a) is applied in forward()
        self.a = nn.Parameter(
            nn.init.constant_(torch.zeros(n_node, args.n_z), 0.5),
            requires_grad=True,
        )
        self.gamma = Parameter(torch.zeros(1))

        # pre-train the AE and IGAE independently (30 iterations in the
        # default setup) before the joint training in fit()
        self.ae.fit(self.data, self.train_loader, self.label)
        self.igae.fit(
            self.adj_orig_graph,
            self.data,
            self.adj,
        )
        # get the best pre-trained AE and IGAE models
        # self.ae = self.ae.get_best_model()
        # self.igae = self.igae.get_best_model()
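    # NOTE: the `args` fields consumed directly in this class are lr,
    # gamma_value, early_stop, n_input, ae_n_enc_1/2/3, ae_n_dec_1/2/3,
    # ae_epochs and n_z; IGAE(args, device) reads further fields that are
    # defined in igae.py and not listed here.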
    def forward(self):
        """Forward propagation.

        Returns:
            x_hat (torch.Tensor): reconstructed attribute matrix generated by the AE decoder
            z_hat (torch.Tensor): reconstructed weighted attribute matrix generated by the IGAE decoder
            z_tilde (torch.Tensor): clustering embedding
            adj_hat (torch.Tensor): reconstructed adjacency matrix generated by the IGAE decoder
        """
        z_ae, _, _, _ = self.ae.encoder(self.data)
        z_igae, z_igae_adj = self.igae.encoder(self.adj_orig_graph, self.data)
        # element-wise fusion of the attribute (AE) and structure (IGAE) embeddings
        z_i = self.a * z_ae + (1 - self.a) * z_igae
        # local structure enhancement
        z_l = torch.spmm(self.adj, z_i)
        # normalized self-correlation between nodes
        s = torch.mm(z_l, z_l.t())
        s = F.softmax(s, dim=1)
        # global correlation aggregation with a learnable skip connection
        z_g = torch.mm(s, z_l)
        z_tilde = self.gamma * z_g + z_l
        x_hat = self.ae.decoder(z_tilde)
        z_hat, z_hat_adj = self.igae.decoder(self.adj_orig_graph, z_tilde)
        adj_hat = z_igae_adj + z_hat_adj
        # adj_hat = z_hat_adj
        return x_hat, z_hat, z_tilde, adj_hat
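    # Shape sketch of the fusion in forward() (n = n_node, d = args.n_z):
    #   z_ae, z_igae                  : (n, d)  attribute / structure embeddings
    #   z_i = a*z_ae + (1-a)*z_igae   : (n, d)  element-wise fusion
    #   z_l = adj @ z_i               : (n, d)  local structure enhancement
    #   s = softmax(z_l @ z_l.T)      : (n, n)  normalized self-correlation
    #   z_g = s @ z_l                 : (n, d)  global aggregation
    #   z_tilde = gamma*z_g + z_l     : (n, d)  clustering embedding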
    def get_embedding(self):
        """Get the cluster embedding.

        Returns:
            numpy.ndarray: the learned node embedding
        """
        return self.embedding
    # def get_memberships(self):
    #     """Get cluster memberships.
    #
    #     Returns:
    #         numpy.ndarray
    #     """
    #     return self.memberships
    def fit(self, epochs):
        """Fit the SAIF clustering model.

        Args:
            epochs (int): number of training epochs
        """
        print("------------------Pretrain SAIF--------------------")
        best_loss = 1e9
        cnt = 0
        best_epoch = 0

        self.to(self.device)
        optimizer = Adam(self.parameters(), lr=self.lr)
        for epoch in range(epochs):
            self.train()
            optimizer.zero_grad()
            x_hat, z_hat, z_tilde, adj_hat = self.forward()

            loss_ae = F.mse_loss(x_hat, self.data)
            loss_w = F.mse_loss(z_hat, torch.spmm(self.adj, self.data))
            loss_a = F.mse_loss(adj_hat, self.adj.to_dense())
            loss_igae = 0.02 * loss_w + self.gamma_value * loss_a
            # loss_igae = loss_w + self.gamma_value * loss_a
            loss = loss_ae + loss_igae
            loss.backward()
            cur_loss = loss.item()
            optimizer.step()

            # kmeans = KMeans(n_clusters=self.n_clusters,
            #                 n_init=20).fit(z_tilde.data.cpu().numpy())
            # (
            #     ARI_score,
            #     NMI_score,
            #     ACC_score,
            #     Micro_F1_score,
            #     Macro_F1_score,
            # ) = evaluation(self.label, kmeans.labels_)
            print(
                f"Epoch_{epoch}",
                # f":ARI {ARI_score:.4f}",
                # f", NMI {NMI_score:.4f}",
                # f", ACC {ACC_score:.4f}",
                # f", Micro_F1 {Micro_F1_score:.4f}",
                # f", Macro_F1 {Macro_F1_score:.4f}",
                f", Loss {cur_loss}",
            )

            # early stopping: stop once the loss has not improved for
            # `estop_steps` consecutive epochs
            if cur_loss < best_loss:
                cnt = 0
                best_epoch = epoch
                best_loss = cur_loss
                self.embedding = z_tilde.detach().cpu().numpy()
                # self.memberships = kmeans.labels_
            else:
                cnt += 1
                print(f"loss increase count: {cnt}")
                if cnt >= self.estop_steps:
                    print(f"early stopping, best epoch: {best_epoch}")
                    break
        print("------------------End Pretrain SAIF--------------------")
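# Minimal usage sketch (illustrative only). The shapes and values below are
# hypothetical, and `args` must also carry the IGAE hyperparameters, which are
# not visible in this file:
#
#   args = argparse.Namespace(
#       lr=1e-4, gamma_value=0.1, early_stop=10, n_input=feat_dim,
#       ae_n_enc_1=128, ae_n_enc_2=256, ae_n_enc_3=512,
#       ae_n_dec_1=512, ae_n_dec_2=256, ae_n_dec_3=128,
#       ae_epochs=30, n_z=20,
#       # ... IGAE hyperparameters ...
#   )
#   model = SAIF(adj_orig_graph, data, train_loader, label, adj,
#                n_clusters, n_node, device, args)
#   model.fit(epochs=200)
#   embedding = model.get_embedding()  # numpy.ndarray, shape (n_node, n_z)
#   # e.g. KMeans(n_clusters=n_clusters, n_init=20).fit(embedding)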