Source code for egc.model.node_embedding.igae

"""
IGAE Embedding
"""
# pylint: disable=unused-import,too-many-locals
import argparse
from copy import deepcopy

import scipy.sparse as sp
import torch
import torch.nn.functional as F
from dgl.nn.pytorch.conv import GraphConv
from sklearn.cluster import KMeans
from torch import nn
from torch.optim import Adam

from ...model.graph_clustering.base import Base
from ...utils.evaluation import evaluation

# from torch.nn import Parameter


# ---------------------------------- IGAE -----------------------------------------
class IGAE(nn.Module):
    """Symmetric improved graph autoencoder (IGAE).

    This network is required to reconstruct both the weighted attribute matrix
    and the adjacency matrix simultaneously.

    Args:
        args (argparse.Namespace): all parameters
    """

    def __init__(self, args: argparse.Namespace, device):
        super().__init__()
        self.encoder = IGAE_encoder(args)
        self.decoder = IGAE_decoder(args)
        self.lr = args.lr
        self.device = device
        self.epochs = args.igae_epochs
        self.gamma_value = args.gamma_value
        # self.embedding = None
        # self.memberships = None
        # self.ebest_model = None  # early best model
        # self.estop_steps = args.early_stop
    def forward(self, g, feat):
        """Forward propagation.

        Args:
            g (dgl.DGLGraph): graph data in dgl
            feat (torch.Tensor): node features

        Returns:
            z_igae (torch.Tensor): latent embedding of IGAE
            z_hat (torch.Tensor): reconstructed weighted attribute matrix
                generated by the IGAE decoder
            adj_hat (torch.Tensor): reconstructed adjacency matrix (sum of the
                encoder and decoder reconstructions)
        """
        z_igae, z_igae_adj = self.encoder(g, feat)
        z_hat, z_hat_adj = self.decoder(g, z_igae)
        adj_hat = z_igae_adj + z_hat_adj
        # adj_hat = z_hat_adj
        return z_igae, z_hat, adj_hat
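    # Illustrative usage sketch (not part of the original module; ``args``,
    # ``g``, ``feat`` and ``device`` are assumed to be prepared elsewhere).
    # With a feature matrix ``feat`` of shape (N, args.n_input):
    #
    #     model = IGAE(args, device)
    #     z_igae, z_hat, adj_hat = model(g, feat)
    #     # z_igae:  (N, args.gae_n_enc_3)  latent embedding
    #     # z_hat:   (N, args.n_input)      reconstructed weighted attributes
    #     # adj_hat: (N, N)                 reconstructed adjacency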
    # def get_embedding(self):
    #     """Get cluster embedding.
    #
    #     Returns: numpy.ndarray
    #     """
    #     return self.embedding
    #
    # def get_memberships(self):
    #     """Get cluster membership.
    #
    #     Returns: numpy.ndarray
    #     """
    #     return self.memberships
    #
    # def get_best_model(self):
    #     """Get best model by early stopping.
    #
    #     Returns: nn.Module
    #     """
    #     return self.ebest_model
    def fit(self, g, data, adj):
        """Fit an IGAE clustering model.

        Args:
            g (dgl.DGLGraph): graph data in dgl
            data (torch.Tensor): node features
            adj (torch.Tensor): adjacency matrix (sparse torch tensor)
        """
        print("------------------Pretrain IGAE--------------------")
        # original_acc = -1
        # cnt = 0
        # best_epoch = 0
        self.to(self.device)
        optimizer = Adam(self.parameters(), lr=self.lr)
        for epoch in range(self.epochs):
            self.train()
            optimizer.zero_grad()
            _, z_hat, adj_hat = self.forward(g, data)
            loss_w = F.mse_loss(z_hat, torch.spmm(adj, data))
            loss_a = F.mse_loss(adj_hat, adj.to_dense())
            loss_igae = loss_w + self.gamma_value * loss_a
            loss_igae.backward()
            cur_loss = loss_igae.item()
            optimizer.step()
            # kmeans = KMeans(n_clusters=n_clusters,
            #                 n_init=20).fit(z_igae.data.cpu().numpy())
            # (
            #     ARI_score,
            #     NMI_score,
            #     ACC_score,
            #     Micro_F1_score,
            #     Macro_F1_score,
            # ) = evaluation(label, kmeans.labels_)
            print(
                f"EPOCH_{epoch}",
                # f":ARI {ARI_score:.4f}",
                # f", NMI {NMI_score:.4f}",
                # f", ACC {ACC_score:.4f}",
                # f", Micro_F1 {Micro_F1_score:.4f}",
                # f", Macro_F1 {Macro_F1_score:.4f}",
                f", Loss {cur_loss}",
            )
            # early stopping
            # if ACC_score > original_acc:
            #     cnt = 0
            #     best_epoch = epoch
            #     original_acc = ACC_score
            #     self.embedding = z_igae.data.cpu().numpy()
            #     self.memberships = kmeans.labels_
            #     self.ebest_model = deepcopy(self.to(self.device))
            #     # print("best_model saved")
            # else:
            #     cnt += 1
            #     print(f"Acc drops count: {cnt}")
            #     if cnt >= self.estop_steps:
            #         print(f"early stopping, best epoch: {best_epoch}")
            #         break
        print("------------------End Pretrain IGAE--------------------")
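# Hedged preparation sketch (assumption, not part of the original module):
# ``fit`` calls ``torch.spmm(adj, data)`` and ``adj.to_dense()``, so ``adj``
# should arrive as a sparse torch tensor. One possible conversion from a scipy
# sparse matrix (names such as ``adj_scipy`` and ``adj_t`` are illustrative):
#
#     import numpy as np
#     coo = adj_scipy.tocoo()
#     indices = torch.from_numpy(np.vstack((coo.row, coo.col))).long()
#     values = torch.from_numpy(coo.data).float()
#     adj_t = torch.sparse_coo_tensor(indices, values, coo.shape).to(device)
#     model.fit(g, feat, adj_t)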
# class GNNLayer(nn.Module):
#     """GNN layer for IGAE.
#
#     Args:
#         in_features (int): number of input feature dimensions
#         out_features (int): number of output feature dimensions
#     """
#
#     def __init__(self, in_features: int, out_features: int):
#         super().__init__()
#         self.in_features = in_features
#         self.out_features = out_features
#         self.act = nn.Tanh()
#         self.weight = Parameter(torch.FloatTensor(in_features, out_features))
#         torch.nn.init.xavier_uniform_(self.weight)
#
#     def forward(self, features, adj, active=False):
#         """Forward propagation.
#
#         Args:
#             features (torch.Tensor): node features
#             adj (torch.Tensor): adjacency matrix (sparse torch tensor)
#             active (bool): whether to use the activation function
#
#         Returns:
#             output (torch.Tensor): output features
#         """
#         if active:
#             support = self.act(torch.mm(features, self.weight))
#         else:
#             support = torch.mm(features, self.weight)
#         output = torch.spmm(adj, support)
#         return output
class IGAE_encoder(nn.Module):
    """Encoder for IGAE.

    Args:
        args (argparse.Namespace): all parameters
    """

    def __init__(self, args: argparse.Namespace):
        super().__init__()
        self.gnn_1 = GraphConv(args.n_input,
                               args.gae_n_enc_1,
                               activation=nn.Tanh())
        self.gnn_2 = GraphConv(args.gae_n_enc_1,
                               args.gae_n_enc_2,
                               activation=nn.Tanh())
        self.gnn_3 = GraphConv(args.gae_n_enc_2, args.gae_n_enc_3)
        self.s = nn.Sigmoid()
    def forward(self, g, feat):
        """Forward propagation.

        Args:
            g (dgl.DGLGraph): graph data in dgl
            feat (torch.Tensor): node features

        Returns:
            z_igae (torch.Tensor): latent embedding of IGAE
            z_igae_adj (torch.Tensor): reconstructed adjacency matrix generated
                by the IGAE encoder
        """
        z = self.gnn_1(g, feat)
        z = self.gnn_2(g, z)
        z_igae = self.gnn_3(g, z)
        z_igae_adj = self.s(torch.mm(z_igae, z_igae.t()))
        return z_igae, z_igae_adj
class IGAE_decoder(nn.Module):
    """Decoder for IGAE.

    Args:
        args (argparse.Namespace): all parameters
    """

    def __init__(self, args: argparse.Namespace):
        super().__init__()
        self.gnn_4 = GraphConv(args.gae_n_dec_1,
                               args.gae_n_dec_2,
                               activation=nn.Tanh())
        self.gnn_5 = GraphConv(args.gae_n_dec_2,
                               args.gae_n_dec_3,
                               activation=nn.Tanh())
        self.gnn_6 = GraphConv(args.gae_n_dec_3,
                               args.n_input,
                               activation=nn.Tanh())
        self.s = nn.Sigmoid()
    def forward(self, g, z_igae):
        """Forward propagation.

        Args:
            g (dgl.DGLGraph): graph data in dgl
            z_igae (torch.Tensor): latent embedding of IGAE

        Returns:
            z_hat (torch.Tensor): reconstructed weighted attribute matrix
                generated by the IGAE decoder
            z_hat_adj (torch.Tensor): reconstructed adjacency matrix generated
                by the IGAE decoder
        """
        z = self.gnn_4(g, z_igae)
        z = self.gnn_5(g, z)
        z_hat = self.gnn_6(g, z)
        z_hat_adj = self.s(torch.mm(z_hat, z_hat.t()))
        return z_hat, z_hat_adj
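# Hedged end-to-end sketch (illustrative; the ``args`` field names follow this
# file, while the concrete layer sizes, cluster count, and variable names are
# assumptions). Note the decoder input size must match the encoder output size
# (``gae_n_dec_1 == gae_n_enc_3``):
#
#     args = argparse.Namespace(
#         lr=1e-3, igae_epochs=30, gamma_value=0.1,
#         n_input=feat.shape[1],
#         gae_n_enc_1=128, gae_n_enc_2=256, gae_n_enc_3=20,
#         gae_n_dec_1=20, gae_n_dec_2=256, gae_n_dec_3=128,
#     )
#     model = IGAE(args, device)
#     model.fit(g, feat, adj_t)                  # adj_t: sparse torch adjacency
#     with torch.no_grad():
#         z_igae, _, _ = model(g, feat)
#     kmeans = KMeans(n_clusters=7, n_init=20).fit(z_igae.cpu().numpy())
#     memberships = kmeans.labels_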