"""Deep Graph Clustering"""
from typing import List
import dgl
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from ....model.graph_clustering.base import Base
from ....module import InnerProductDecoder
from ....module import MultiLayerDNN
from ....module import MultiLayerGNN
from ....utils import init_weights
from ....utils import load_model
from ....utils import NaiveDataLoader
from ....utils import save_model
from ....utils import sparse_mx_to_torch_sparse_tensor
from ....utils import torch_sparse_to_dgl_graph
class DGC(Base, nn.Module):
"""Deep Graph Clustering"""
def __init__(
self,
in_feats: int,
out_feats_list: List[int],
n_clusters: int,
classifier_hidden_list: List[int] = None,
aggregator_type: str = "gcn",
bias: bool = True,
k: int = 20,
tau: float = 0.9999,
encoder_act: List[str] = None,
classifier_act: List[str] = None,
dropout: float = 0.0,
n_epochs: int = 1000,
n_pretrain_epochs: int = 800,
lr: float = 0.01,
l2_coef: float = 0.0,
early_stopping_epoch: int = 20,
model_filename: str = "dgc_mlp_gsl",
):
super().__init__()
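        # Base.__init__ may not initialize nn.Module, so call it explicitly as well.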
nn.Module.__init__(self)
self.n_clusters = n_clusters
self.n_epochs = n_epochs
self.n_pretrain_epochs = n_pretrain_epochs
self.early_stopping_epoch = early_stopping_epoch
self.model_filename = model_filename
self.n_layers = len(out_feats_list)
self.k = k
self.tau = tau
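        # Set lazily in fit().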
self.norm = None
self.pos_weight = None
self.device = None
self.batch_size = None
self.encoder = MultiLayerGNN(
in_feats=in_feats,
out_feats_list=out_feats_list,
aggregator_type=aggregator_type,
bias=bias,
activation=encoder_act,
dropout=dropout,
)
self.decoder = InnerProductDecoder()
self.classifier = MultiLayerDNN(
in_feats=out_feats_list[-1],
out_feats_list=[n_clusters]
if classifier_hidden_list is None else classifier_hidden_list,
activation=["softmax"]
if classifier_act is None else classifier_act,
)
self.optimizer = torch.optim.Adam(
self.parameters(),
lr=lr,
weight_decay=l2_coef,
)
for module in self.modules():
init_weights(module)

    def load_best_model(self, device: torch.device) -> None:
        """Reload the best checkpoint saved during pretraining and move it to ``device``."""
        # ``load_model`` is expected to restore the parameters of ``self`` in place.
        self, _, _, _ = load_model(self.model_filename, self, self.optimizer)
        self.to(device)

    def forward(self, blocks):
        """Encode blocks into soft cluster assignments and a reconstructed adjacency."""
        z = self.encoder(blocks, blocks[0].srcdata["feat"])
        preds = self.classifier(z)
        adj_hat = self.decoder(preds)
        return preds, adj_hat

    def loss(self, adj_hat: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Re-weighted binary cross-entropy between the reconstructed and target adjacency."""
        return self.norm * F.binary_cross_entropy_with_logits(
            adj_hat.view(-1),
            adj.view(-1),
            pos_weight=self.pos_weight,
        )

    def pretrain(
        self,
        data_loader: NaiveDataLoader,
        adj_label: torch.Tensor,
    ) -> None:
        """Pretrain on adjacency reconstruction, with checkpointing and early stopping."""
        cnt_wait = 0
        best = 1e9
        print("Pretrain Encoder Start.")
for epoch in range(self.n_pretrain_epochs):
self.train()
loss_epoch = 0
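            # Each loader step yields a one-element tuple, presumably
            # (input_nodes, output_nodes, blocks) as in DGL's DataLoader.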
for _, ((_, _, blocks), ) in enumerate(data_loader):
self.optimizer.zero_grad()
_, adj_hat = self.forward(blocks)
loss = self.loss(adj_hat, adj_label)
loss.backward()
self.optimizer.step()
loss_epoch += loss.item()
            if epoch % 50 == 0:
                print(
                    f"Pretrain Epoch: {epoch + 1:04d}\tLoss: {loss_epoch / len(data_loader):.8f}"
                )
if round(loss_epoch, 5) < best:
best = loss_epoch
cnt_wait = 0
save_model(
self.model_filename,
self,
self.optimizer,
epoch,
loss_epoch,
)
            else:
                cnt_wait += 1
                if cnt_wait == self.early_stopping_epoch:
                    print("Early stopping.")
                    break
        print("Pretrain Encoder Done.")
        self.load_best_model(self.device)

    def learn_structure(self, preds: torch.Tensor) -> torch.Tensor:
        """Learn a new adjacency from the soft assignments, keeping each node's top-k neighbors."""
        # Pairwise similarity of the soft cluster assignments.
        q = torch.mm(preds, preds.t())
        # Zero out everything except each node's k most similar neighbors.
        _, indices = q.topk(k=self.k, dim=-1)
        mask = torch.zeros_like(q)
        mask[torch.arange(q.shape[0]).view(-1, 1), indices] = 1
        adj_new = F.relu(mask * q)
        return adj_new

    def fit(
        self,
        graph: dgl.DGLGraph,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """Alternate between reconstruction pretraining and graph structure refinement."""
        self.device = device
        # Full-batch training: each batch covers every node in the graph.
        self.batch_size = graph.num_nodes()
        adj_raw = graph.adj().to(device)
        adj = adj_raw.to_dense()
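        # Each outer iteration: recompute the loss weights, pretrain, then refine the adjacency.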
for _ in range(self.n_epochs):
adj_sum = adj.sum()
            # BCE positive-class weight: #non-edges / #edges = (|V|**2 - |E|) / |E|
self.pos_weight = torch.FloatTensor([
float(adj.shape[0] * adj.shape[0] - adj_sum) / adj_sum
]).to(device)
            # Reconstruction loss normalization: |V|**2 / (2 * (|V|**2 - |E|))
self.norm = (adj.shape[0] * adj.shape[0] / float(
(adj.shape[0] * adj.shape[0] - adj_sum) * 2))
data_loader = NaiveDataLoader(
graph=graph,
batch_size=self.batch_size,
n_layers=self.n_layers,
device=device,
)
self.to(device)
self.pretrain(data_loader=data_loader, adj_label=adj)
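            # Soft assignments from the best pretrained model drive the structure update.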
preds = self.get_embedding(graph, device)
adj = self.tau * adj + (1 - self.tau) * self.learn_structure(preds)
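            # Rebuild a DGLGraph from the updated adjacency and carry the node features over.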
graph_new = torch_sparse_to_dgl_graph(
sparse_mx_to_torch_sparse_tensor(
sp.csr_matrix(adj.cpu().numpy())))
graph_new.ndata["feat"] = graph.ndata["feat"].to(device)
graph = graph_new

    def get_embedding(
        self,
        graph: dgl.DGLGraph,
        device: torch.device = torch.device("cpu"),
    ) -> torch.Tensor:
        """Get the node-level embeddings (the classifier's soft assignments).

        Note that this first reloads the best saved checkpoint.

        Returns:
            (torch.Tensor): node embedding.
        """
self.load_best_model(device=device)
data_loader = NaiveDataLoader(
graph=graph,
batch_size=self.batch_size,
n_layers=self.n_layers,
device=device,
)
        embeddings = []
        for _, ((_, _, blocks), ) in tqdm(
                enumerate(data_loader),
                desc="Inference:",
                total=len(data_loader),
        ):
            with torch.no_grad():
                preds, _ = self.forward(blocks)
            embeddings.append(preds.cpu())
        return torch.cat(embeddings).to(device)

    def get_memberships(
        self,
        graph: dgl.DGLGraph,
        device: torch.device = torch.device("cpu"),
    ) -> np.ndarray:
        """Get the hard cluster memberships (argmax of the soft assignments).

        Returns:
            np.ndarray: cluster membership of each node.
        """
return self.get_embedding(graph, device).argmax(dim=1).cpu().numpy()
# # for test only
# if __name__ == '__main__':
# from utils import load_data
# from utils.evaluation import evaluation
# from utils import set_device
# from utils import set_seed
# device = set_device('0')
# set_seed(4096)
# graph, label, n_clusters = load_data(dataset_name='Cora')
# features = graph.ndata["feat"]
# model = DGC(
# in_feats=features.shape[1],
# out_feats_list=[500],
# classifier_hidden_list=[7],
# aggregator_type='mean',
# n_clusters=n_clusters,
# bias=True,
# k=20,
# tau=0.9999,
# encoder_act=['relu'],
# classifier_act=['softmax'],
# dropout=0.0,
# n_epochs=1000,
# n_pretrain_epochs=800,
# lr=0.001,
# l2_coef=0.0,
# early_stopping_epoch=20,
# model_filename='dgc_mlp_gsl',
# )
# model.fit(graph=graph, device=device)
# res = model.get_memberships(graph, device)
# (
# ARI_score,
# NMI_score,
# ACC_score,
# Micro_F1_score,
# Macro_F1_score,
# ) = evaluation(label, res)
# print("\n"
# f"ARI:{ARI_score}\n"
# f"NMI:{ NMI_score}\n"
# f"ACC:{ACC_score}\n"
# f"Micro F1:{Micro_F1_score}\n"
# f"Macro F1:{Macro_F1_score}\n")