"""
Contrastive Multi-View Representation Learning on Graphs
https://arxiv.org/abs/2006.05582
"""
import numpy as np
import scipy.sparse as sp
import torch
from sklearn.preprocessing import MinMaxScaler
from torch import nn
from ...module import BATCH_GCN
from ...module import DiscMVGRL
from ...utils import compute_ppr
from ...utils import normalize_feature
from ...utils import save_model
from ...utils import sk_clustering
from ...utils import sparse_mx_to_torch_sparse_tensor
from ...utils import symmetrically_normalize_adj
# from utils import load_model
# Borrowed from https://github.com/PetarV-/DGI
class Readout(nn.Module):
    """Masked average readout over nodes (borrowed from DGI)."""
    @staticmethod
    def forward(seq, msk):
"""Forward Propagation
Args:
seq (torch.Tensor): features tensor.
msk (torch.Tensor): node mask.
Returns:
(torch.Tensor): graph-level representation
"""
if msk is None:
return torch.mean(seq, 1)
msk = torch.unsqueeze(msk, -1)
return torch.mean(seq * msk, 1) / torch.sum(msk)
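# Shape sketch (hypothetical values): Readout()(torch.randn(2, 5, 8), None)
# pools over the node axis and returns a tensor of shape (2, 8).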
class MVGRL(nn.Module):
    """MVGRL: Contrastive Multi-View Representation Learning on Graphs
Args:
in_feats (int): Input feature size.
n_clusters (int): Num of clusters.
        n_h (int,optional): Hidden units dimension. Defaults to 512.
model_filename (str,optional): Path to store model parameters. Defaults to 'mvgrl'.
sparse (bool,optional): Use sparse tensor. Defaults to False.
nb_epochs (int,optional): Maximum training epochs. Defaults to 3000.
patience (int,optional): Early stopping patience. Defaults to 20.
lr (float,optional): Learning rate. Defaults to 0.001.
weight_decay (float,optional): Weight decay. Defaults to 0.0.
sample_size (int,optional): Sample size. Defaults to 2000.
batch_size (int,optional): Batch size. Defaults to 4.
dataset (str,optional): Dataset. Defaults to 'Citeseer'.
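
    Example:
        A minimal usage sketch; the DGL-style ``graph`` object is an
        assumption, following the commented test at the bottom of this file::

            features = graph.ndata["feat"]
            model = MVGRL(in_feats=features.shape[1], n_clusters=n_clusters)
            model.fit(graph.adj_external(scipy_fmt="csr"), features)
            memberships = model.get_memberships()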
"""
def __init__(
self,
in_feats: int,
n_clusters: int,
n_h: int = 512,
model_filename: str = "mvgrl",
sparse: bool = False,
nb_epochs: int = 3000,
patience: int = 20,
lr: float = 0.001,
weight_decay: float = 0.0,
sample_size: int = 2000,
batch_size: int = 4,
dataset: str = "Citeseer",
):
super().__init__()
self.n_clusters = n_clusters
self.model_filename = model_filename
self.sparse = sparse
self.nb_epochs = nb_epochs
self.patience = patience
self.lr = lr
self.weight_decay = weight_decay
self.sample_size = sample_size
self.batch_size = batch_size
self.dataset = dataset
self.adj = None
self.diff = None
self.features = None
self.optimizer = None
self.msk = None
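        # One single-layer GCN encoder per view: gcn1 encodes the raw
        # adjacency, gcn2 the PPR diffusion matrix.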
self.gcn1 = BATCH_GCN(in_feats, n_h)
self.gcn2 = BATCH_GCN(in_feats, n_h)
self.read = Readout()
self.sigm = nn.Sigmoid()
self.disc = DiscMVGRL(n_h)
    def forward(self, seq1, seq2, adj, diff, sparse, msk):
"""Forward Propagation
Args:
seq1 (torch.Tensor): features of raw graph
seq2 (torch.Tensor): shuffle features of diffuse graph
adj (torch.Tensor): adj matrix of raw graph
diff (torch.Tensor): ppr matrix of diffuse graph
sparse (bool): if sparse
msk (torch.Tensor): mask node
Returns:
ret (torch.Tensor): probability of positive or negtive node
h_1 (torch.Tensor): node embedding of raw graph by one gcn layer
h_2 (torch.Tensor): node embedding of diffuse graph by one gcn layer
"""
h_1 = self.gcn1(seq1, adj, sparse)
c_1 = self.read(h_1, msk)
c_1 = self.sigm(c_1)
h_2 = self.gcn2(seq1, diff, sparse)
c_2 = self.read(h_2, msk)
c_2 = self.sigm(c_2)
h_3 = self.gcn1(seq2, adj, sparse)
h_4 = self.gcn2(seq2, diff, sparse)
ret = self.disc(c_1, c_2, h_1, h_2, h_3, h_4)
return ret, h_1, h_2
    def fit(self, adj_csr, features):
        """Fitting

        Args:
            adj_csr (sp.csr_matrix): sparse adjacency matrix.
            features (torch.Tensor): node features.
"""
# adj_csr = graph.adj_external(scipy_fmt='csr')
self.adj = adj_csr.toarray()
self.diff = compute_ppr(self.adj, 0.2)
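        # compute_ppr builds the diffusion view via Personalized PageRank with
        # teleport probability alpha = 0.2; in the MVGRL paper this is
        # S = alpha * (I - (1 - alpha) * D^{-1/2} (A + I) D^{-1/2})^{-1}.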
# self.features = graph.ndata["feat"].numpy()
self.features = features.numpy()
if self.dataset == "Citeseer":
self.features = sp.lil_matrix(self.features)
self.features = normalize_feature(self.features)
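        # Sparsify the dense diffusion matrix: pick the threshold epsilon whose
        # induced average degree best matches the original graph's, zero out
        # entries below it, then min-max rescale to [0, 1].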
epsilons = [1e-5, 1e-4, 1e-3, 1e-2]
avg_degree = np.sum(self.adj) / self.adj.shape[0]
epsilon = epsilons[np.argmin([
abs(avg_degree -
np.argwhere(self.diff >= e).shape[0] / self.diff.shape[0])
for e in epsilons
])]
self.diff[self.diff < epsilon] = 0.0
scaler = MinMaxScaler()
scaler.fit(self.diff)
self.diff = scaler.transform(self.diff)
self.adj = symmetrically_normalize_adj(
self.adj + sp.eye(self.adj.shape[0])).todense()
ft_size = self.features.shape[1]
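        # Discriminator targets per batch: the first 2 * sample_size logits
        # are positives (ones), the remaining 2 * sample_size negatives (zeros).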
lbl_1 = torch.ones(self.batch_size, self.sample_size * 2)
lbl_2 = torch.zeros(self.batch_size, self.sample_size * 2)
lbl = torch.cat((lbl_1, lbl_2), 1)
self.optimizer = torch.optim.Adam(self.parameters(),
lr=self.lr,
weight_decay=self.weight_decay)
if torch.cuda.is_available():
self.cuda()
lbl = lbl.cuda()
b_xent = nn.BCEWithLogitsLoss()
cnt_wait = 0
best = 1e9
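        # Each epoch samples `batch_size` contiguous node blocks of size
        # `sample_size` and trains on the induced subgraphs of both views.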
for epoch in range(self.nb_epochs):
idx = np.random.randint(0,
self.adj.shape[-1] - self.sample_size + 1,
self.batch_size)
ba, bd, bf = [], [], []
for i in idx:
ba.append(self.adj[i:i + self.sample_size,
i:i + self.sample_size])
bd.append(self.diff[i:i + self.sample_size,
i:i + self.sample_size])
bf.append(self.features[i:i + self.sample_size])
ba = np.array(ba).reshape(self.batch_size, self.sample_size,
self.sample_size)
bd = np.array(bd).reshape(self.batch_size, self.sample_size,
self.sample_size)
bf = np.array(bf).reshape(self.batch_size, self.sample_size,
ft_size)
if self.sparse:
ba = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(ba))
bd = sparse_mx_to_torch_sparse_tensor(sp.coo_matrix(bd))
else:
ba = torch.FloatTensor(ba)
bd = torch.FloatTensor(bd)
bf = torch.FloatTensor(bf)
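            # DGI-style corruption: permute the node rows of the features to
            # build negative samples.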
idx = np.random.permutation(self.sample_size)
shuf_fts = bf[:, idx, :]
if torch.cuda.is_available():
bf = bf.cuda()
ba = ba.cuda()
bd = bd.cuda()
shuf_fts = shuf_fts.cuda()
self.train()
self.optimizer.zero_grad()
logits, __, __ = self.forward(bf, shuf_fts, ba, bd, self.sparse,
None)
loss = b_xent(logits, lbl)
loss.backward()
self.optimizer.step()
print(f"Epoch: {epoch}, Loss: {loss.item()}")
if loss < best:
best = loss
cnt_wait = 0
save_model(self.model_filename, self, self.optimizer, epoch,
loss.item())
# torch.save(self.state_dict(), 'model.pkl')
else:
cnt_wait += 1
if cnt_wait == self.patience:
print("Early stopping!")
break
    def get_embedding(self):
        """Get the embeddings (node and graph level).

        Returns:
            (torch.Tensor): embedding of each node.
            (torch.Tensor): graph-level representation.
"""
# model, _, _, _ = load_model(self.model_filename, self, self.optimizer)
        adj = torch.FloatTensor(self.adj[np.newaxis])
        diff = torch.FloatTensor(self.diff[np.newaxis])
        features = torch.FloatTensor(self.features[np.newaxis])
        if torch.cuda.is_available():
            adj = adj.cuda()
            diff = diff.cuda()
            features = features.cuda()
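        # Following the MVGRL paper, the final node representation is the sum
        # of the two views' embeddings.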
h_1 = self.gcn1(features, adj, self.sparse)
c = self.read(h_1, self.msk)
h_2 = self.gcn2(features, diff, self.sparse)
return (h_1 + h_2).detach(), c.detach()
    def get_memberships(self):
"""Get memberships
Returns:
np.ndarray: memberships
"""
pred, _ = self.get_embedding()
return sk_clustering(torch.squeeze(pred, 0).cpu(),
self.n_clusters,
name="kmeans")
# # for test only
# if __name__ == '__main__':
# from utils import load_data
# from utils.evaluation import evaluation,best_mapping
# from utils import set_device
# from utils import set_seed
# import scipy.sparse as sp
# import time
# import pandas as pd
# set_seed(4096)
# device = set_device('1')
# dataset = 'ACM'
# graph, label, n_clusters = load_data(
# dataset_name=dataset,
# directory='./data',
# )
# print(graph)
# features = graph.ndata["feat"]
# start_time = time.time()
# model = MVGRL(in_feats=features.shape[1],
# n_clusters=n_clusters,
# n_h=512,
# lr=0.001,
# dataset=dataset)
#     model.fit(graph.adj_external(scipy_fmt="csr"), features)
# res = model.get_memberships()
# elapsed_time = time.time() - start_time
# (
# ARI_score,
# NMI_score,
# ACC_score,
# Micro_F1_score,
# Macro_F1_score,
# ) = evaluation(label, res)
# print("\n"
# f"Elapsed Time:{elapsed_time:.2f}s\n"
# f"ARI:{ARI_score}\n"
# f"NMI:{ NMI_score}\n"
# f"ACC:{ACC_score}\n"
# f"Micro F1:{Micro_F1_score}\n"
# f"Macro F1:{Macro_F1_score}\n")
# labels_true, labels_pred = best_mapping(label.cpu().numpy(), res)
# df_res = pd.DataFrame({'label':labels_true,'pred':labels_pred})
# df_res.to_pickle(f'./tmp/MVGRL_{dataset}_pred.pkl')
# print('write to',f'./tmp/MVGRL_{dataset}_pred.pkl')