"""Embedding By GMI
Adapted From: https://github.com/zpeng27/GMI
"""
from typing import Tuple
import numpy as np
import scipy.sparse as sp
import torch
from torch import nn
from ...module import GMI
from ...utils import get_repeat_shuffle_nodes_list
from ...utils import normalize_feature
from ...utils import sparse_mx_to_torch_sparse_tensor
from ...utils import symmetrically_normalize_adj


def mi_loss_jsd(pos: torch.Tensor, neg: torch.Tensor) -> torch.Tensor:
    """Jensen-Shannon MI Estimator

    Args:
        pos (torch.Tensor): :math:`D_w(h_i, x_i)` or :math:`D_w(h_i, x_j)`.
        neg (torch.Tensor): :math:`D_w(h_i, x'_i)` or :math:`D_w(h_i, x'_j)`.

    Returns:
        (torch.Tensor): JSD loss.

    .. math::
        & sp(-D_w(h_i, x_i)) + E(sp(D_w(h_i, x'_i))) \\\\
        & \\textbf{or} \\\\
        & sp(-D_w(h_i, x_j)) + E(sp(D_w(h_i, x'_j))),

    where :math:`sp(x) = \\log(1 + e^x)` is the softplus function.
    """
    # sp(x) = log(1 + exp(x)); the numerically stable softplus avoids
    # overflow in exp() for large discriminator scores.
    e_pos = torch.mean(nn.functional.softplus(-pos))
    e_neg = torch.mean(nn.functional.softplus(neg))
    return e_pos + e_neg
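

# Illustrative sanity check (not from the original repo): with all-zero
# discriminator scores, both terms equal log(2), so the loss is 2 * log(2).
#
#     >>> pos = torch.zeros(4, 1)   # D_w(h_i, x_i) = 0 for every pair
#     >>> neg = torch.zeros(4, 4)   # D_w(h_i, x'_i) = 0 for every pair
#     >>> mi_loss_jsd(pos, neg)
#     tensor(1.3863)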


def reconstruct_loss(pred: torch.Tensor, gnd: np.ndarray) -> torch.Tensor:
    """Reconstruction Loss of the Rebuilt Adjacency Matrix

    Args:
        pred (torch.Tensor): :math:`w_{ij}`.
        gnd (np.ndarray): :math:`a_{ij}`.

    Returns:
        (torch.Tensor): reconstruction loss.

    .. math::
        \\text{reconstruct}_{loss} =
        & \\frac{n^2}{n^2 - |E|} * AVG(\\frac{-(n^2 - |E|)}{|E|} *
        a_{ij} * \\log(w_{ij} + 10^{-10}) \\\\
        & - (1 - a_{ij}) * \\log(1 - w_{ij} + 10^{-10})).
    """
    nodes_n = gnd.shape[0]
    edges_n = np.sum(gnd) / 2
    # weight1 up-weights the sparse positive entries; weight2 rescales the
    # overall mean to compensate.
    weight1 = (nodes_n * nodes_n - edges_n) * 1.0 / edges_n
    weight2 = nodes_n * nodes_n * 1.0 / (nodes_n * nodes_n - edges_n)
    # Move the target to the prediction's device rather than hardcoding
    # .cuda(), which would fail on CPU-only machines.
    gnd = torch.as_tensor(gnd, dtype=torch.float32, device=pred.device)
    temp1 = gnd * torch.log(pred + (1e-10)) * (-weight1)
    temp2 = (1 - gnd) * torch.log(1 - pred + (1e-10))
    return torch.mean(temp1 - temp2) * weight2
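

# Weighting sketch (illustrative numbers, not from the original repo): for a
# 3-node graph whose target matrix sums to 4 (so edges_n = 2), positive
# entries are scaled by weight1 = (9 - 2) / 2 = 3.5 and the averaged result
# by weight2 = 9 / (9 - 2), compensating for adjacency sparsity.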


def preprocess_adj(
        adj_orig: sp.csr_matrix) -> Tuple[sp.csr_matrix, np.ndarray]:
    """Preprocess the Adjacency Matrix: Row Averaging and Self Loops

    Args:
        adj_orig (sp.csr_matrix): input original adjacency matrix.

    Returns:
        adj_orig, adj_target (sp.csr_matrix, np.ndarray):
            the row-averaged adjacency and the self-loop target adjacency.
    """
    adj_dense = adj_orig.toarray()
    adj_row_avg = 1.0 / np.sum(adj_dense, axis=1)
    # Zero out inf/nan entries produced by isolated (degree-0) nodes.
    adj_row_avg[np.isnan(adj_row_avg)] = 0.0
    adj_row_avg[np.isinf(adj_row_avg)] = 0.0
    # Scale each row by the inverse of its degree (row averaging).
    adj_dense = adj_dense * adj_row_avg[:, np.newaxis]
    adj_orig = sp.csr_matrix(adj_dense, dtype=np.float32)
    adj_target = adj_dense + np.eye(adj_dense.shape[0])
    return adj_orig, adj_target
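

# Row-averaging sketch (illustrative, not from the original repo): for the
# adjacency [[0, 1, 1], [1, 0, 0], [1, 0, 0]], node 0 has degree 2, so its
# row becomes [0, 0.5, 0.5]; adj_target additionally adds self loops, giving
# [[1, 0.5, 0.5], [1, 1, 0], [1, 0, 1]].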


class GMIEmbed(nn.Module):
    """GMI Embedding

    Args:
        in_features (int): input feature dimension.
        hidden_units (int, optional): hidden units size of gcn. Defaults to 512.
        n_epochs (int, optional): number of embedding training epochs. Defaults to 550.
        early_stopping_epoch (int, optional): early stopping threshold. Defaults to 20.
        lr (float, optional): learning rate. Defaults to 0.001.
        l2_coef (float, optional): weight decay. Defaults to 0.0.
        alpha (float, optional): weight of :math:`I(h_i; x_i)`. Defaults to 0.8.
        beta (float, optional): weight of :math:`I(h_i; x_j)`. Defaults to 1.0.
        gamma (float, optional): weight of :math:`I(w_{ij}; a_{ij})`. Defaults to 1.0.
        activation (str, optional): activation of gcn layer. Defaults to "prelu".
        gcn_depth (int, optional): number of gcn layers. Defaults to 2.
    """
    def __init__(
        self,
        in_features: int,
        hidden_units: int = 512,
        n_epochs: int = 550,
        early_stopping_epoch: int = 20,
        lr: float = 0.001,
        l2_coef: float = 0.0,
        alpha: float = 0.8,
        beta: float = 1.0,
        gamma: float = 1.0,
        activation: str = "prelu",
        gcn_depth: int = 2,
    ) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.n_epochs = n_epochs
        self.early_stopping_epoch = early_stopping_epoch
        self.features_norm = None
        self.adj_orig = None
        self.adj_norm = None
        self.adj_target = None
        self.model = GMI(in_features,
                         hidden_units,
                         gcn_depth=gcn_depth,
                         activation=activation)
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          lr=lr,
                                          weight_decay=l2_coef)

    def calc_loss(
        self,
        mi_pos: torch.Tensor,
        mi_neg: torch.Tensor,
        local_mi_pos: torch.Tensor,
        local_mi_neg: torch.Tensor,
        adj_rebuilt: torch.Tensor,
    ) -> torch.Tensor:
        """Calculate Loss

        Args:
            mi_pos (torch.Tensor): :math:`D_w(h_i, x_i)`.
            mi_neg (torch.Tensor): :math:`D_w(h_i, x'_i)`.
            local_mi_pos (torch.Tensor): :math:`D_w(h_i, x_j)`.
            local_mi_neg (torch.Tensor): :math:`D_w(h_i, x'_j)`.
            adj_rebuilt (torch.Tensor): :math:`w_{ij}`.

        Returns:
            (torch.Tensor): loss.

        .. math::
            loss = & \\alpha * (sp(-D_w(h_i, x_i)) + E(sp(D_w(h_i, x'_i)))) \\\\
            & + \\beta * (sp(-D_w(h_i, x_j)) + E(sp(D_w(h_i, x'_j)))) \\\\
            & + \\gamma * \\text{reconstruct}_{loss}.
        """
        return (self.alpha * mi_loss_jsd(mi_pos, mi_neg) +
                self.beta * mi_loss_jsd(local_mi_pos, local_mi_neg) +
                self.gamma * reconstruct_loss(adj_rebuilt, self.adj_target))

    def forward(self, neg_sample_list: torch.Tensor) -> torch.Tensor:
        """Forward Propagation

        Args:
            neg_sample_list (torch.Tensor): negative sample list.

        Returns:
            (torch.Tensor): loss.
        """
        mi_pos, mi_neg, local_mi_pos, local_mi_neg, adj_rebuilt = self.model(
            self.features_norm, self.adj_orig, self.adj_norm, neg_sample_list)
        return self.calc_loss(mi_pos, mi_neg, local_mi_pos, local_mi_neg,
                              adj_rebuilt)

    def fit(
        self,
        features: sp.lil_matrix,
        adj_orig: sp.csr_matrix,
        neg_list_num: int = 5,
    ) -> None:
        """Fit on a Specific Graph

        Args:
            features (sp.lil_matrix): 2D sparse features.
            adj_orig (sp.csr_matrix): 2D sparse adjacency matrix.
            neg_list_num (int, optional): number of negative sample lists. Defaults to 5.
        """
        self.features_norm = torch.FloatTensor(
            normalize_feature(features)[np.newaxis])
        self.adj_norm = sparse_mx_to_torch_sparse_tensor(
            symmetrically_normalize_adj(adj_orig + sp.eye(adj_orig.shape[0])))
        self.adj_orig, self.adj_target = preprocess_adj(adj_orig)
        if torch.cuda.is_available():
            print("GPU available: GMI Embedding Using CUDA")
            self.model.cuda()
            self.features_norm = self.features_norm.cuda()
            self.adj_norm = self.adj_norm.cuda()
        best = 1e9
        cnt_wait = 0
        for epoch in range(self.n_epochs):
            self.model.train()
            self.optimizer.zero_grad()
            # Draw fresh negative samples (shuffled node lists) every epoch.
            neg_sample_list = get_repeat_shuffle_nodes_list(
                adj_orig.shape[0], neg_list_num)
            loss = self.forward(neg_sample_list)
            loss_value = loss.item()
            print(f"Epoch: {epoch + 1} Loss: {loss_value:.4f}")
            # Checkpoint on improvement; stop once the loss has not improved
            # for `early_stopping_epoch` consecutive epochs.
            if loss_value < best:
                best = loss_value
                cnt_wait = 0
                torch.save(self.model.state_dict(), "best_gmi.pkl")
            else:
                cnt_wait += 1
                if cnt_wait == self.early_stopping_epoch:
                    print("Early stopping!")
                    break
            loss.backward()
            self.optimizer.step()

    def set_features_norm(self, features_norm: torch.Tensor) -> None:
        """Set the row-normalized features

        Args:
            features_norm (torch.Tensor): row-normalized 3D feature tensor of shape [1, num_nodes, num_features].
        """
        self.features_norm = features_norm

    def set_adj_norm(self, adj_norm: torch.Tensor) -> None:
        """Set the symmetrically normalized adjacency

        Args:
            adj_norm (torch.Tensor): symmetrically normalized 2D adjacency tensor.
        """
        self.adj_norm = adj_norm

    def get_features_norm(self) -> torch.Tensor:
        """Get the row-normalized features

        Returns:
            features_norm (torch.Tensor): row-normalized 3D feature tensor of shape [1, num_nodes, num_features].
        """
        return self.features_norm

    def get_adj_norm(self) -> torch.Tensor:
        """Get the symmetrically normalized adjacency

        Returns:
            adj_norm (torch.Tensor): symmetrically normalized 2D adjacency tensor.
        """
        return self.adj_norm

    def get_embedding(self) -> torch.Tensor:
        """Get the embeddings (graph or node level).

        Returns:
            (torch.Tensor): embedding.
        """
        # Restore the best checkpoint saved during fit() before embedding.
        self.model.load_state_dict(torch.load("best_gmi.pkl"))
        return self.model.get_embedding(self.features_norm, self.adj_norm)
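

# Typical usage (illustrative sketch, not from the original repo; assumes the
# package's relative imports resolve, i.e. this module is imported as part of
# its library rather than run directly):
#
#     import numpy as np
#     import scipy.sparse as sp
#     adj = sp.csr_matrix(np.array([[0, 1, 1],
#                                   [1, 0, 0],
#                                   [1, 0, 0]], dtype=np.float32))
#     features = sp.lil_matrix(np.random.rand(3, 16).astype(np.float32))
#     embedder = GMIEmbed(in_features=16, hidden_units=8, n_epochs=50)
#     embedder.fit(features, adj)
#     embeddings = embedder.get_embedding()   # shape [1, 3, 8]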