"""
IGAE Embedding
"""
# pylint: disable=unused-import,too-many-locals
import argparse
from copy import deepcopy
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from dgl.nn.pytorch.conv import GraphConv
from sklearn.cluster import KMeans
from torch import nn
from torch.optim import Adam
from ...model.graph_clustering.base import Base
from ...utils.evaluation import evaluation
# from torch.nn import Parameter
# ---------------------------------- IGAE -----------------------------------------
class IGAE(nn.Module):
    """Symmetric improved graph autoencoder (IGAE).

    The network reconstructs both the weighted attribute matrix and the
    adjacency matrix simultaneously: the encoder and the decoder each emit
    an inner-product adjacency estimate, and the two are summed.

    Args:
        args (argparse.Namespace): all parameters (lr, igae_epochs,
            gamma_value, layer sizes).
        device: torch device the model is moved to for training.
    """

    def __init__(self, args: argparse.Namespace, device):
        super().__init__()
        self.encoder = IGAE_encoder(args)
        self.decoder = IGAE_decoder(args)
        self.lr = args.lr
        self.device = device
        self.epochs = args.igae_epochs
        # Weight of the adjacency reconstruction term in the total loss.
        self.gamma_value = args.gamma_value

    def forward(self, g, feat):
        """Forward propagation.

        Args:
            g (dgl.DGLGraph): graph data in dgl.
            feat (torch.Tensor): node features.

        Returns:
            z_igae (torch.Tensor): latent embedding of IGAE.
            z_hat (torch.Tensor): reconstructed weighted attribute matrix
                generated by the IGAE decoder.
            adj_hat (torch.Tensor): reconstructed adjacency matrix (sum of the
                encoder-side and decoder-side inner-product estimates).
        """
        z_igae, z_igae_adj = self.encoder(g, feat)
        z_hat, z_hat_adj = self.decoder(g, z_igae)
        adj_hat = z_igae_adj + z_hat_adj
        return z_igae, z_hat, adj_hat

    def fit(self, g, data, adj):
        """Fit (pretrain) the IGAE model.

        Minimizes ``mse(z_hat, adj @ data) + gamma * mse(adj_hat, adj)``.

        Args:
            g (dgl.DGLGraph): graph data in dgl.
            data (torch.Tensor): node features.
            adj (torch.Tensor): adjacency matrix as a torch sparse tensor
                (it is fed to ``torch.spmm`` and ``.to_dense()``).
        """
        print("------------------Pretrain IGAE--------------------")
        self.to(self.device)
        optimizer = Adam(self.parameters(), lr=self.lr)
        # Hoist the loop-invariant reconstruction targets out of the epoch
        # loop: both depend only on the fixed adjacency and input features,
        # so recomputing the spmm / densification every epoch is wasted work.
        target_feat = torch.spmm(adj, data)
        target_adj = adj.to_dense()
        for epoch in range(self.epochs):
            self.train()
            optimizer.zero_grad()
            _, z_hat, adj_hat = self(g, data)
            loss_w = F.mse_loss(z_hat, target_feat)
            loss_a = F.mse_loss(adj_hat, target_adj)
            loss_igae = loss_w + self.gamma_value * loss_a
            loss_igae.backward()
            cur_loss = loss_igae.item()
            optimizer.step()
            print(
                f"EPOCH_{epoch}",
                f", Loss {cur_loss}",
            )
        print("------------------End Pretrain IGAE--------------------")
# class GNNLayer(nn.Module):
# """gnn layer for IGAE
#
# Args:
# in_features (int): Number of input features' dimension
# out_features (int): Number of output features' dimension
#
# """
# def __init__(self, in_features: int, out_features: int):
# super().__init__()
# self.in_features = in_features
# self.out_features = out_features
# self.act = nn.Tanh()
# self.weight = Parameter(torch.FloatTensor(in_features, out_features))
# torch.nn.init.xavier_uniform_(self.weight)
#
# def forward(self, features, adj, active=False):
# """Forward Propagation
#
# Args:
# features (torch.Tensor):node's features
# adj (sp.csr.csr_matrix): adjacency matrix
# active (bool):Whether to use the activation function
#
# Returns:
# output (int):output features
#
# """
# if active:
# support = self.act(torch.mm(features, self.weight))
# else:
# support = torch.mm(features, self.weight)
# output = torch.spmm(adj, support)
# return output
class IGAE_encoder(nn.Module):
    """Three-layer graph-convolutional encoder for IGAE.

    Produces the latent embedding plus an inner-product estimate of the
    adjacency matrix computed from that embedding.

    Args:
        args (argparse.Namespace): all parameters (layer sizes).
    """

    def __init__(self, args: argparse.Namespace):
        super().__init__()
        self.gnn_1 = GraphConv(args.n_input, args.gae_n_enc_1, activation=nn.Tanh())
        self.gnn_2 = GraphConv(args.gae_n_enc_1, args.gae_n_enc_2, activation=nn.Tanh())
        # Final layer is linear: the embedding itself is unactivated.
        self.gnn_3 = GraphConv(args.gae_n_enc_2, args.gae_n_enc_3)
        self.s = nn.Sigmoid()

    def forward(self, g, feat):
        """Forward propagation.

        Args:
            g (dgl.DGLGraph): graph data in dgl.
            feat (torch.Tensor): node features.

        Returns:
            z_igae (torch.Tensor): latent embedding of IGAE.
            z_igae_adj (torch.Tensor): adjacency matrix reconstructed from the
                embedding via a sigmoid-activated inner product.
        """
        hidden = self.gnn_1(g, feat)
        hidden = self.gnn_2(g, hidden)
        z_igae = self.gnn_3(g, hidden)
        z_igae_adj = self.s(z_igae @ z_igae.t())
        return z_igae, z_igae_adj
class IGAE_decoder(nn.Module):
    """Three-layer graph-convolutional decoder for IGAE.

    Maps the latent embedding back to the input feature dimension and also
    emits an inner-product estimate of the adjacency matrix computed from
    the reconstructed features.

    Args:
        args (argparse.Namespace): all parameters (layer sizes).
    """

    def __init__(self, args: argparse.Namespace):
        super().__init__()
        self.gnn_4 = GraphConv(args.gae_n_dec_1, args.gae_n_dec_2, activation=nn.Tanh())
        self.gnn_5 = GraphConv(args.gae_n_dec_2, args.gae_n_dec_3, activation=nn.Tanh())
        self.gnn_6 = GraphConv(args.gae_n_dec_3, args.n_input, activation=nn.Tanh())
        self.s = nn.Sigmoid()

    def forward(self, g, z_igae):
        """Forward propagation.

        Args:
            g (dgl.DGLGraph): graph data in dgl.
            z_igae (torch.Tensor): latent embedding of IGAE.

        Returns:
            z_hat (torch.Tensor): reconstructed weighted attribute matrix.
            z_hat_adj (torch.Tensor): adjacency matrix reconstructed from
                ``z_hat`` via a sigmoid-activated inner product.
        """
        hidden = self.gnn_4(g, z_igae)
        hidden = self.gnn_5(g, hidden)
        z_hat = self.gnn_6(g, hidden)
        z_hat_adj = self.s(z_hat @ z_hat.t())
        return z_hat, z_hat_adj