"""
AE Embedding
"""
# pylint: disable=unused-import
import argparse
from copy import deepcopy
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from torch import nn
from torch.nn import Linear
from torch.optim import Adam
from torch.utils.data import DataLoader
from ...utils.evaluation import evaluation
# ---------------------------------- AE -----------------------------------------
class AE(nn.Module):
    """AutoEncoder model.

    Args:
        n_input (int): dimension of the input features.
        n_clusters (int): number of clusters.
        hidden1 (int, optional): hidden units of encoder layer 1. Defaults to 500.
        hidden2 (int, optional): hidden units of encoder layer 2. Defaults to 500.
        hidden3 (int, optional): hidden units of encoder layer 3. Defaults to 2000.
        hidden4 (int, optional): hidden units of decoder layer 1. Defaults to 2000.
        hidden5 (int, optional): hidden units of decoder layer 2. Defaults to 500.
        hidden6 (int, optional): hidden units of decoder layer 3. Defaults to 500.
        lr (float, optional): learning rate. Defaults to 0.0005.
        epochs (int, optional): number of embedding training epochs. Defaults to 100.
        n_z (int, optional): dimension of the latent embedding z. Defaults to 10.
        activation (str, optional): activation of the encoder/decoder layers. Defaults to 'relu'.
        early_stop (int, optional): patience (epochs without improvement) for early stopping. Defaults to 20.
        if_eva (bool, optional): whether to run KMeans on the embedding after each epoch to judge its quality. Defaults to False.
        if_early_stop (bool, optional): whether to use early stopping. Defaults to False.
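
    Example (an illustrative sketch; the feature dimension and cluster
    count are assumed values, not library defaults):

        >>> model = AE(n_input=784, n_clusters=10)
        >>> x = torch.rand(8, 784)
        >>> x_hat, z_ae, z_1, z_2, z_3 = model(x)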
"""

    def __init__(
        self,
        n_input: int,
        n_clusters: int,
        hidden1: int = 500,
        hidden2: int = 500,
        hidden3: int = 2000,
        hidden4: int = 2000,
        hidden5: int = 500,
        hidden6: int = 500,
        lr: float = 0.0005,
        epochs: int = 100,
        n_z: int = 10,
        activation: str = "relu",
        early_stop: int = 20,
        if_eva: bool = False,
        if_early_stop: bool = False,
    ):
        super().__init__()
        if activation == "leakyrelu":
            self.activation = nn.LeakyReLU(0.2, inplace=True)
        elif activation == "relu":
            self.activation = nn.ReLU()
        else:
            # fall back to whatever callable the caller supplied
            self.activation = activation
        self.encoder = AE_encoder(
            n_input=n_input,
            hidden1=hidden1,
            hidden2=hidden2,
            hidden3=hidden3,
            n_z=n_z,
            activation=self.activation,
        )
        self.decoder = AE_decoder(
            n_input=n_input,
            hidden1=hidden4,
            hidden2=hidden5,
            hidden3=hidden6,
            n_z=n_z,
            activation=self.activation,
        )
        self.embedding = None
        self.memberships = None
        self.ebest_model = None  # best model snapshot kept by early stopping
        self.estop_steps = early_stop
        self.lr = lr
        self.epochs = epochs
        self.n_clusters = n_clusters
        self.if_eva = if_eva
        self.if_early_stop = if_early_stop

    def forward(self, x):
        """Forward propagation.

        Args:
            x (torch.Tensor): node features.

        Returns:
            x_hat (torch.Tensor): attribute matrix reconstructed by the AE decoder.
            z_ae (torch.Tensor): latent embedding of the AE.
            z_1, z_2, z_3 (torch.Tensor): intermediate encoder activations.
        """
        z_ae, z_1, z_2, z_3 = self.encoder(x)
        x_hat = self.decoder(z_ae)
        return x_hat, z_ae, z_1, z_2, z_3

    def get_embedding(self):
        """Get the cluster embedding.

        Returns:
            numpy.ndarray
        """
        return self.embedding

    def get_memberships(self):
        """Get the cluster memberships.

        Returns:
            numpy.ndarray
        """
        return self.memberships

    def get_best_model(self):
        """Get the best model found by early stopping.

        Returns:
            nn.Module
        """
        return self.ebest_model

    def fit(self, data, train_loader, label) -> None:
        """Fit the AE model.

        Args:
            data (torch.Tensor): node features.
            train_loader (DataLoader): DataLoader over the training features.
            label (torch.Tensor): node labels, used only for evaluation.
        """
        if torch.cuda.is_available():
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
print("------------------Pretrain AE--------------------")
original_acc = -1
cnt = 0
best_epoch = 0
# for name, param in self.named_parameters():
# if param.requires_grad:
# print(name)
optimizer = Adam(self.parameters(), lr=self.lr)
        for epoch in range(self.epochs):
            loss_list = []
            self.train()
            for x, _ in train_loader:
                if torch.cuda.is_available():
                    x = x.cuda()
                x_hat, _, _, _, _ = self.forward(x)
                loss = F.mse_loss(x_hat, x)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_list.append(loss.item())
            print(f"epoch:{epoch} loss:{np.mean(loss_list):.6f}")
            if self.if_eva:
                self.eval()
                with torch.no_grad():
                    x_hat, z_ae, _, _, _ = self.forward(
                        data.cuda() if torch.cuda.is_available() else data)
                kmeans = KMeans(n_clusters=self.n_clusters,
                                n_init=20).fit(z_ae.data.cpu().numpy())
                (
                    ARI_score,
                    NMI_score,
                    AMI_score,
                    ACC_score,
                    Micro_F1_score,
                    Macro_F1_score,
                    purity,
                ) = evaluation(label, kmeans.labels_)
                print(
                    f"epoch_{epoch}",
                    f":ARI {ARI_score:.4f}",
                    f", NMI {NMI_score:.4f}",
                    f", AMI {AMI_score:.4f}",
                    f", ACC {ACC_score:.4f}",
                    f", Micro_F1 {Micro_F1_score:.4f}",
                    f", Macro_F1 {Macro_F1_score:.4f}",
                    f", purity {purity:.4f}",
                )
                if self.if_early_stop:
                    # early stopping on clustering accuracy
                    if ACC_score > original_acc:
                        cnt = 0
                        best_epoch = epoch
                        original_acc = ACC_score
                        self.embedding = z_ae.data.cpu().numpy()
                        self.memberships = kmeans.labels_
                        # clear the old snapshot first so deepcopy(self) does
                        # not recursively copy previously saved models
                        self.ebest_model = None
                        self.ebest_model = deepcopy(self)
                    else:
                        cnt += 1
                        print(f"acc has not improved for {cnt} epoch(s)")
                        if cnt >= self.estop_steps:
                            print(f"early stopping, best epoch: {best_epoch}")
                            break
        print("------------------End Pretrain AE--------------------")


class AE_encoder(nn.Module):
    """Encoder for the AE.

    Args:
        n_input (int): dimension of the input features.
        hidden1 (int): hidden units of encoder layer 1.
        hidden2 (int): hidden units of encoder layer 2.
        hidden3 (int): hidden units of encoder layer 3.
        n_z (int): dimension of the latent embedding z.
        activation (object): activation module shared by the encoder layers.
    """

    def __init__(
        self,
        n_input: int,
        hidden1: int,
        hidden2: int,
        hidden3: int,
        n_z: int,
        activation: object,
    ):
        super().__init__()
        self.enc_1 = Linear(n_input, hidden1, bias=False)
        self.enc_2 = Linear(hidden1, hidden2, bias=False)
        self.enc_3 = Linear(hidden2, hidden3, bias=False)
        self.z_layer = Linear(hidden3, n_z, bias=False)
        self.activation = activation

    def forward(self, x):
        """Forward propagation.

        Args:
            x (torch.Tensor): node features.

        Returns:
            z_ae (torch.Tensor): latent embedding of the AE.
            z_1, z_2, z_3 (torch.Tensor): intermediate encoder activations.
        """
        z_1 = self.activation(self.enc_1(x))
        z_2 = self.activation(self.enc_2(z_1))
        z_3 = self.activation(self.enc_3(z_2))
        z_ae = self.z_layer(z_3)
        return z_ae, z_1, z_2, z_3


class AE_decoder(nn.Module):
    """Decoder for the AE.

    Args:
        n_input (int): dimension of the reconstructed features.
        hidden1 (int): hidden units of decoder layer 1.
        hidden2 (int): hidden units of decoder layer 2.
        hidden3 (int): hidden units of decoder layer 3.
        n_z (int): dimension of the latent embedding z.
        activation (object): activation module shared by the decoder layers.
    """

    def __init__(
        self,
        n_input: int,
        hidden1: int,
        hidden2: int,
        hidden3: int,
        n_z: int,
        activation: object,
    ):
        super().__init__()
        self.dec_1 = Linear(n_z, hidden1, bias=False)
        self.dec_2 = Linear(hidden1, hidden2, bias=False)
        self.dec_3 = Linear(hidden2, hidden3, bias=False)
        self.x_bar_layer = Linear(hidden3, n_input, bias=False)
        self.activation = activation

    def forward(self, z_ae):
        """Forward propagation.

        Args:
            z_ae (torch.Tensor): latent embedding of the AE.

        Returns:
            x_hat (torch.Tensor): attribute matrix reconstructed by the AE decoder.
        """
        z = self.activation(self.dec_1(z_ae))
        z = self.activation(self.dec_2(z))
        z = self.activation(self.dec_3(z))
        x_hat = self.x_bar_layer(z)
        return x_hat
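

# ------------------------------ usage sketch ------------------------------
# A minimal end-to-end sketch of pretraining the AE on synthetic data. The
# feature dimension, cluster count, batch size, and epoch count below are
# illustrative assumptions, not values from the library. Note that the
# relative import of `evaluation` above means this module must be run from
# within its package (e.g. via `python -m ...`), not as a standalone script.
if __name__ == "__main__":
    from torch.utils.data import TensorDataset

    features = torch.rand(256, 784)
    labels = torch.randint(0, 10, (256,))
    loader = DataLoader(TensorDataset(features, labels),
                        batch_size=64, shuffle=True)

    model = AE(n_input=784, n_clusters=10, epochs=5)
    model.fit(features, loader, labels)
    # With if_eva and if_early_stop enabled, the fitted embedding and
    # memberships would be exposed via get_embedding() and get_memberships().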