"""
Graph Debiased Contrastive Learning with Joint Representation Clustering
https://www.ijcai.org/proceedings/2021/0473.pdf
"""
import random

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn.functional as F
from sklearn.cluster import KMeans
from sklearn.preprocessing import MinMaxScaler
from torch import nn
from torch.nn.parameter import Parameter

from ....utils import compute_ppr
from ....utils import get_checkpoint_path
from ....utils import normalize_feature
from ....utils import save_model
from ....utils import sparse_mx_to_torch_sparse_tensor
from ....utils import symmetrically_normalize_adj
from ....utils.evaluation import evaluation
from ...node_embedding.mvgrl import MVGRL

# pylint:disable=too-many-branches,too-many-statements


# Borrowed from https://github.com/PetarV-/DGI
class Readout(nn.Module):
    """Average readout: pool node features into a graph-level representation."""

    @staticmethod
    def forward(seq, msk):
        """Forward Propagation

        Args:
            seq (torch.Tensor): features tensor.
            msk (torch.Tensor): node mask.

        Returns:
            (torch.Tensor): graph-level representation
        """
        if msk is None:
            return torch.mean(seq, 1)
        msk = torch.unsqueeze(msk, -1)
        # masked mean: sum over the unmasked nodes divided by the mask size.
        # (The original used torch.mean here, which divides by the node count
        # a second time; DGI's AvgReadout uses torch.sum.)
        return torch.sum(seq * msk, 1) / torch.sum(msk)
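

# A minimal usage sketch of Readout (hypothetical tensors, not part of the
# original module): it pools node features into one graph-level vector,
# averaging only over the nodes selected by ``msk`` when a mask is given.
#
#   seq = torch.rand(1, 5, 16)                       # [batch, nodes, features]
#   msk = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0]])  # keep the first 3 nodes
#   g = Readout.forward(seq, msk)                    # mean of the 3 kept nodes
#   assert g.shape == (1, 16)
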

class GDCL(nn.Module):
    """GDCL: Graph Debiased Contrastive Learning with Joint Representation Clustering

    Args:
        in_feats (int): Input feature size.
        n_clusters (int): Number of clusters.
        n_h (int): Hidden units dimension. Defaults to 512.
        nb_epochs (int): Epoch number of GDCL. Defaults to 1500.
        lr (float): Learning rate of GDCL. Defaults to 0.00005.
        alpha (float): Alpha parameter of the Student's t-distribution. Defaults to 0.0001.
        mask_num (int): Number of feature dimensions masked per node. Defaults to 100.
        batch_size (int): Batch size of GDCL. Defaults to 4.
        update_interval (int): Update interval of the target distribution. Defaults to 10.
        model_filename (str): Model filename of GDCL. Defaults to 'gdcl'.
        beta (float): Balance factor between the contrastive and clustering losses. Defaults to 10e-4.
        weight_decay (float): Weight decay of GDCL. Defaults to 0.0.
        pt_n_h (int): Hidden units dimension of pretrained MVGRL. Defaults to 512.
        pt_model_filename (str): Model filename of pretrained MVGRL. Defaults to 'mvgrl'.
        pt_nb_epochs (int): Epoch number of pretrained MVGRL. Defaults to 3000.
        pt_patience (int): Patience of pretrained MVGRL. Defaults to 20.
        pt_lr (float): Learning rate of pretrained MVGRL. Defaults to 0.001.
        pt_weight_decay (float): Weight decay of pretrained MVGRL. Defaults to 0.0.
        pt_sample_size (int): Sample size of pretrained MVGRL. Defaults to 2000.
        pt_batch_size (int): Batch size of pretrained MVGRL. Defaults to 4.
        sparse (bool): Whether to use sparse tensors. Defaults to False.
        dataset (str): Dataset name. Defaults to 'Citeseer'.
        device (torch.device): Device. Defaults to torch.device('cpu').
    """

    def __init__(
        self,
        in_feats,
        n_clusters,
        n_h: int = 512,
        nb_epochs: int = 1500,
        lr: float = 0.00005,
        alpha=0.0001,
        mask_num: int = 100,
        batch_size: int = 4,
        update_interval: int = 10,
        model_filename: str = "gdcl",
        beta: float = 10e-4,
        weight_decay: float = 0.0,
        pt_n_h: int = 512,
        pt_model_filename: str = "mvgrl",
        pt_nb_epochs: int = 3000,
        pt_patience: int = 20,
        pt_lr: float = 0.001,
        pt_weight_decay: float = 0.0,
        pt_sample_size: int = 2000,
        pt_batch_size: int = 4,
        sparse: bool = False,
        dataset: str = "Citeseer",
        device: torch.device = torch.device("cpu"),
    ):
        super().__init__()
        self.n_clusters = n_clusters
        self.nb_epochs = nb_epochs
        self.lr = lr
        self.sparse = sparse
        self.alpha = alpha
        self.pretrain_path = get_checkpoint_path(pt_model_filename)
        self.mask_num = mask_num
        self.dataset = dataset
        self.batch_size = batch_size
        self.update_interval = update_interval
        self.model_filename = model_filename
        self.beta = beta
        self.weight_decay = weight_decay
        self.device = device
        self.adj = None
        self.diff = None
        self.features = None
        self.optimizer = None
        self.mvg = MVGRL(
            in_feats=in_feats,
            n_clusters=n_clusters,
            n_h=pt_n_h,
            model_filename=pt_model_filename,
            sparse=sparse,
            nb_epochs=pt_nb_epochs,
            patience=pt_patience,
            lr=pt_lr,
            weight_decay=pt_weight_decay,
            sample_size=pt_sample_size,
            batch_size=pt_batch_size,
            dataset=dataset,
        )
        # cluster layer
        self.cluster_layer = Parameter(torch.Tensor(n_clusters, n_h))
        torch.nn.init.xavier_normal_(self.cluster_layer.data)

    def pretrain(self, graph):
        """Pretrain the MVGRL encoder.

        Args:
            graph (dgl.DGLGraph): graph.
        """
        print("pretrained MVGRL starting...")
        self.mvg.fit(adj_csr=graph.adj_external(scipy_fmt="csr"),
                     features=graph.ndata["feat"])
        print("pretrained MVGRL ending...")

    def embed(self, seq, adj, diff, sparse):
        """Embed.

        Args:
            seq (torch.Tensor): features of raw graph
            adj (torch.Tensor): adj matrix of raw graph
            diff (torch.Tensor): ppr matrix of diffused graph
            sparse (bool): if sparse

        Returns:
            (torch.Tensor): node embedding
        """
        h_1 = self.mvg.gcn1(seq, adj, sparse)
        h = self.mvg.gcn2(seq, diff, sparse)
        return (h + h_1).detach()

    def forward(self, bf, mask_fts, bd, sparse):
        """Forward Propagation

        Args:
            bf (torch.Tensor): features of raw graph
            mask_fts (torch.Tensor): masked features
            bd (torch.Tensor): ppr matrix of diffused graph
            sparse (bool): if sparse

        Returns:
            h_mask (torch.Tensor): node embedding of the masked-feature graph
            h (torch.Tensor): node embedding of the raw graph
            q (torch.Tensor): soft cluster assignment
        """
        h_mask = self.mvg.gcn2(mask_fts, bd, sparse)[0].unsqueeze(0)
        h = self.mvg.gcn2(bf, bd, sparse)[0].unsqueeze(0)

        # Student's t-distribution soft assignment (as in DEC):
        # q_ij is proportional to (1 + ||h_i - mu_j||^2 / alpha)^(-(alpha+1)/2)
        q = 1.0 / (1.0 + torch.sum(
            torch.pow(
                h.reshape(-1, h.shape[2]).unsqueeze(1) - self.cluster_layer,
                2),
            2,
        ) / self.alpha)
        q = q.pow((self.alpha + 1.0) / 2.0)
        q = (q.t() / torch.sum(q, 1)).t()

        return h_mask, h, q
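
    # A worked sketch of the soft assignment above (toy numbers; 2 clusters,
    # 3 one-dimensional embeddings, alpha = 1 -- all hypothetical). Each row
    # of q is a distribution over clusters that decays with the squared
    # distance to each centroid:
    #
    #   h = torch.tensor([[0.0], [1.0], [10.0]])   # [n, d]
    #   mu = torch.tensor([[0.0], [10.0]])         # [n_clusters, d]
    #   q = 1.0 / (1.0 + torch.sum((h.unsqueeze(1) - mu) ** 2, 2) / 1.0)
    #   q = q.pow((1.0 + 1.0) / 2.0)               # no-op for alpha = 1
    #   q = (q.t() / torch.sum(q, 1)).t()
    #   # rows sum to 1: node 0 -> [0.99, 0.01], node 2 -> [0.01, 0.99]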

    def fit(self, graph, labels):
        """Fit the model on the given graph.

        Args:
            graph (dgl.DGLGraph): graph.
            labels (torch.Tensor): labels of each node.
        """
        adj_csr = graph.adj_external(scipy_fmt="csr")
        self.adj = adj_csr.toarray()
        self.diff = compute_ppr(self.adj, 0.2)
        self.features = graph.ndata["feat"].numpy()
        if self.dataset == "Citeseer":
            self.features = sp.lil_matrix(self.features)
            self.features = normalize_feature(self.features)

        # choose the sparsification threshold whose surviving average degree
        # best matches the average degree of the raw graph
        epsilons = [1e-5, 1e-4, 1e-3, 1e-2]
        avg_degree = np.sum(self.adj) / self.adj.shape[0]
        epsilon = epsilons[np.argmin([
            abs(avg_degree -
                np.argwhere(self.diff >= e).shape[0] / self.diff.shape[0])
            for e in epsilons
        ])]

        self.diff[self.diff < epsilon] = 0.0
        scaler = MinMaxScaler()
        scaler.fit(self.diff)
        self.diff = scaler.transform(self.diff)

        self.adj = symmetrically_normalize_adj(
            self.adj + sp.eye(self.adj.shape[0])).todense()

        ft_size = self.features.shape[1]
        sample_size = self.features.shape[0]
        labels = torch.LongTensor(labels)

        self.optimizer = torch.optim.Adam(self.parameters(),
                                          lr=self.lr,
                                          weight_decay=self.weight_decay)

        self.pretrain(graph)

        if self.sparse:
            self.adj = sparse_mx_to_torch_sparse_tensor(
                sp.coo_matrix(self.adj))
            self.diff = sparse_mx_to_torch_sparse_tensor(
                sp.coo_matrix(self.diff))

        features_array = self.features
        diff_array = self.diff
        self.features = torch.FloatTensor(self.features[np.newaxis])
        self.adj = torch.FloatTensor(self.adj[np.newaxis])
        self.diff = torch.FloatTensor(self.diff[np.newaxis])
        self.features = self.features.to(self.device)
        self.adj = self.adj.to(self.device)
        self.diff = self.diff.to(self.device)

        # obtain features of positive samples; clone so the in-place masking
        # below does not also zero out self.features
        features_mask = self.features.clone()  # [1, n, d]
        for i in range(features_mask.shape[1]):
            idx = random.sample(range(1, features_mask.shape[2]),
                                self.mask_num)
            features_mask[0][i][idx] = 0  # randomly mask features to 0
        features_mask_array = np.array(features_mask.squeeze(0).cpu())

        # initialize cluster centers with k-means on the diffusion view
        h2 = self.mvg.gcn2(self.features, self.diff, self.sparse)
        kmeans = KMeans(n_clusters=self.n_clusters)
        y_pred = kmeans.fit_predict(h2.data.squeeze().cpu().numpy())
        self.cluster_layer.data = torch.tensor(kmeans.cluster_centers_).to(
            self.device)

        self.train()
        acc_clu = 0
        kl_loss = 0
        loss = 0
        for epoch in range(self.nb_epochs):
            idx = np.random.randint(0, self.adj.shape[-1] - sample_size + 1,
                                    self.batch_size)
            bd, bf, bf_mask = [], [], []
            for i in idx:
                bd.append(diff_array[i:i + sample_size, i:i + sample_size])
                bf.append(features_array[i:i + sample_size])
                bf_mask.append(features_mask_array[i:i + sample_size])
            bd = np.array(bd).reshape(self.batch_size, sample_size,
                                      sample_size)
            bf = np.array(bf).reshape(self.batch_size, sample_size, ft_size)
            bf_mask = np.array(bf_mask).reshape(self.batch_size, sample_size,
                                                ft_size)

            if self.sparse:
                bd = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(bd))
            else:
                bd = torch.FloatTensor(bd)
            bf = torch.FloatTensor(bf)
            bf_mask = torch.FloatTensor(bf_mask)

            bf = bf.to(self.device)
            bd = bd.to(self.device)
            bf_mask = bf_mask.to(self.device)

            if epoch % self.update_interval == 0:
                _, _, tmp_q = self.forward(bf, bf_mask, bd, self.sparse)

                # update target distribution p
                tmp_q = tmp_q.data
                p = target_distribution(tmp_q)

                # evaluate clustering performance
                y_pred = self.get_memberships()
                (_, _, _, ACC_score, _, _,
                 _) = evaluation(np.array(labels.cpu()), y_pred)
                print("ACC_score:", ACC_score)
                if ACC_score > acc_clu:
                    acc_clu = ACC_score
                    save_model(
                        self.model_filename,
                        self,
                        self.optimizer,
                        epoch,
                        loss.item() if loss != 0 else 0,
                    )

            h_mask, h_2_sour, q = self.forward(bf, bf_mask, bd, self.sparse)
            kl_loss = F.kl_div(q.log(), p)

            temperature = 0.5
            y_sam = torch.LongTensor(y_pred)

            # ---------- draw negative samples from the other clusters --------
            neg_size = 1000
            class_sam = []
            for m in range(np.max(y_pred) + 1):
                class_del = torch.ones(int(sample_size), dtype=bool)
                class_del[np.where(y_sam.cpu() == m)] = 0
                class_neg = torch.arange(sample_size).masked_select(class_del)
                if class_neg.shape[0] >= neg_size:
                    neg_sam_id = random.sample(range(class_neg.shape[0]),
                                               int(neg_size))
                else:
                    # fall back to sampling with replacement when fewer than
                    # neg_size negatives exist (the original raised
                    # "sample larger than population" here)
                    neg_sam_id = random.choices(range(class_neg.shape[0]),
                                                k=int(neg_size))
                class_sam.append(class_neg[neg_sam_id])  # [n_class, neg_size]

            out = h_2_sour.squeeze()  # shape: [sample_size, d]
            neg = torch.exp(torch.mm(out, out.t().contiguous()) /
                            temperature)  # shape: [sample_size, sample_size]
            neg_samp = torch.zeros(neg.shape[0],
                                   int(neg_size))  # [sample_size, neg_size]
            for n in range(np.max(y_pred) + 1):
                neg_samp[np.where(y_sam.cpu() == n)] = neg.cpu().index_select(
                    1, class_sam[n])[np.where(y_sam.cpu() == n)]
            # use the configured device rather than a hard-coded .cuda()
            neg_samp = neg_samp.to(self.device)
            Ng = neg_samp.sum(dim=-1)

            # ---------- draw positive samples from the same cluster ----------
            pos_size = 10
            class_sam_pos = []
            for m in range(np.max(y_pred) + 1):
                class_del = torch.ones(int(sample_size), dtype=bool)
                class_del[np.where(y_sam.cpu() != m)] = 0
                class_pos = torch.arange(sample_size).masked_select(class_del)
                if class_pos.shape[0] >= pos_size:
                    pos_sam_id = random.sample(range(class_pos.shape[0]),
                                               int(pos_size))
                else:
                    # sample with replacement for clusters smaller than
                    # pos_size (the original crashed here)
                    pos_sam_id = random.choices(range(class_pos.shape[0]),
                                                k=int(pos_size))
                # index into class_pos: the original indexed class_neg here,
                # which drew negatives as positives
                class_sam_pos.append(class_pos[pos_sam_id])

            pos = torch.exp(torch.mm(out, out.t().contiguous()))
            pos_samp = torch.zeros(pos.shape[0],
                                   int(pos_size))  # [sample_size, pos_size]
            for n in range(np.max(y_pred) + 1):
                pos_samp[np.where(y_sam.cpu() == n)] = pos.cpu().index_select(
                    1, class_sam_pos[n])[np.where(y_sam.cpu() == n)]
            pos_samp = pos_samp.to(self.device)
            pos = pos_samp.sum(dim=-1) + torch.diag(
                torch.exp(torch.mm(out, (h_mask.squeeze()).t().contiguous())))

            node_contra_loss_2 = (-torch.log(pos / (pos + Ng))).mean()
            loss = node_contra_loss_2 + self.beta * kl_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
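
    # A minimal sketch of the per-node objective assembled in ``fit`` above
    # (stand-in tensors, hypothetical sizes): with ``pos`` the summed
    # similarities to same-cluster samples plus the masked-view positive, and
    # ``Ng`` the summed similarities to other-cluster negatives, each node
    # contributes -log(pos / (pos + Ng)):
    #
    #   pos = torch.exp(torch.rand(4))   # stand-in positive similarities
    #   Ng = torch.exp(torch.rand(4))    # stand-in negative similarities
    #   node_contra_loss = (-torch.log(pos / (pos + Ng))).mean()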

    def get_embedding(self):
        """Get the node embeddings.

        Returns:
            (torch.Tensor): embedding of each node.
        """
        h_2 = self.mvg.gcn2(self.features, self.diff, self.sparse)
        return h_2.detach()

    def get_memberships(self):
        """Get memberships.

        Returns:
            np.ndarray: cluster membership of each node.
        """
        h = self.get_embedding()
        # same Student's t soft assignment as in forward()
        q = 1.0 / (1.0 + torch.sum(
            torch.pow(
                h.reshape(-1, h.shape[2]).unsqueeze(1) - self.cluster_layer,
                2),
            2,
        ) / self.alpha)
        q = q.pow((self.alpha + 1.0) / 2.0)
        q = (q.t() / torch.sum(q, 1)).t()
        y_pred = q.detach().cpu().numpy().argmax(1)
        return y_pred


def target_distribution(q):
    """Get the target distribution P.

    Args:
        q (torch.Tensor): soft assignments.

    Returns:
        torch.Tensor: target distribution P.
    """
    weight = q**2 / q.sum(0)
    return (weight.t() / weight.sum(1)).t()
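

# A quick check of the sharpening behaviour (toy q, hypothetical values):
# squaring each assignment and renormalising pushes every row of P toward
# its dominant cluster, while dividing by q.sum(0) discounts large clusters;
# the KL term in ``fit`` trains q to match this sharper target.
#
#   q = torch.tensor([[0.8, 0.2], [0.2, 0.8]])
#   p = target_distribution(q)
#   # p ~= [[0.94, 0.06], [0.06, 0.94]] -- sharper than q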


# # for test only
# if __name__ == '__main__':
#     import time
#
#     from utils import load_data
#     from utils import set_device
#     from utils import set_seed
#     from utils.evaluation import evaluation
#
#     set_seed(4096)
#     device = set_device('4')
#     graph, label, n_clusters = load_data(
#         dataset_name='Citeseer',
#         directory='./data',
#     )
#     print(graph)
#     features = graph.ndata["feat"]
#     start_time = time.time()
#     model = GDCL(
#         in_feats=features.shape[1],
#         n_clusters=n_clusters,
#         device=device,
#     )
#     model.fit(graph=graph, labels=label)
#     res = model.get_memberships()
#     elapsed_time = time.time() - start_time
#     # evaluation returns seven scores (see fit); the original unpacked only
#     # five, so the F1 positions below are an assumption
#     (
#         ARI_score,
#         NMI_score,
#         _,
#         ACC_score,
#         Micro_F1_score,
#         Macro_F1_score,
#         _,
#     ) = evaluation(label, res)
#     print("\n"
#           f"Elapsed Time:{elapsed_time:.2f}s\n"
#           f"ARI:{ARI_score}\n"
#           f"NMI:{NMI_score}\n"
#           f"ACC:{ACC_score}\n"
#           f"Micro F1:{Micro_F1_score}\n"
#           f"Macro F1:{Macro_F1_score}\n")