Source code for egc.utils.sampling
"""
Sample Method
"""
import multiprocessing
import random
from typing import Dict
from typing import List
from typing import Tuple
import numpy as np
import torch
def get_repeat_shuffle_nodes_list(n_nodes, sample_times):
    """Get Negative Sample Nodes List By Repeatable Shuffle

    For every sample round, draw one node index for each position ``i`` in
    ``range(n_nodes)`` uniformly from all node indices except ``i`` itself.
    Draws are independent, so an index may repeat within a round (a shuffle
    *with repetition*, not a permutation).

    Args:
        n_nodes (int): node number in all.
        sample_times (int): sample times.

    Returns:
        (List): ``sample_times`` lists of ``n_nodes`` Python ints each, where
        the entry at position ``i`` is never ``i``. When ``n_nodes == 0`` each
        list is empty.

    Raises:
        ValueError: if ``n_nodes`` is negative, or ``n_nodes == 1`` (the only
            node cannot be paired with a different node).
    """
    if n_nodes < 0:
        raise ValueError(f"n_nodes must be non-negative, got {n_nodes}")
    if n_nodes == 0:
        return [[] for _ in range(sample_times)]
    if n_nodes == 1:
        # The previous rejection loop never terminated in this case.
        raise ValueError("n_nodes must be at least 2 to sample a node other than itself")

    positions = np.arange(n_nodes)
    sample_list = []
    for _ in range(sample_times):
        # Draw from the n_nodes - 1 candidates that exclude each position,
        # then shift draws at or above the position up by one to skip it.
        # This is uniform over {0, ..., n_nodes - 1} \ {i}, matching the
        # distribution of the original rejection sampling.
        draws = np.random.randint(0, n_nodes - 1, size=n_nodes)
        draws += (draws >= positions)
        sample_list.append(draws.tolist())
    return sample_list
def normal_reparameterize(mu: torch.Tensor,
                          logvar: torch.Tensor,
                          training: bool = True) -> torch.Tensor:
    """Reparameterization trick for a normal distribution.

    In training mode, standard-normal noise shaped like ``logvar`` is scaled
    by ``exp(logvar)`` and shifted by ``mu``. Outside training, ``mu`` itself
    is returned without sampling.

    Args:
        mu (torch.Tensor): mean.
        logvar (torch.Tensor): log sigma (log standard deviation); it is
            exponentiated directly to obtain the scale.
        training (bool): whether to sample (True) or return ``mu`` (False).

    Returns:
        (torch.Tensor): ``eps * exp(logvar) + mu`` when training, else ``mu``.
    """
    # Guard clause: evaluation mode uses the mean deterministically.
    if not training:
        return mu
    # NOTE(review): the name "logvar" suggests log variance, for which the
    # scale would be exp(0.5 * logvar). The code and the original docstring
    # treat it as log sigma; confirm against callers before changing.
    sigma = torch.exp(logvar)
    noise = torch.randn_like(sigma)
    # mul() produces a fresh tensor, so the in-place add_ never touches mu or logvar.
    return noise.mul(sigma).add_(mu)
######################################################################################
# START: This section of code is adapted from https://github.com/SamJia/CommunityGAN #
######################################################################################
def agm(x: np.ndarray) -> np.ndarray:
    """AGM probability

    Computes ``1 - exp(-x)`` element-wise, maps NaN results to 0 and clips the
    result into ``[1e-6, 1]``. The clipping keeps every value strictly
    positive, so the result can be used directly as sampling weights.

    Args:
        x (np.ndarray): 1-d array (Python/NumPy scalars and arrays of any
            shape are also accepted).

    Returns:
        np.ndarray: AGM probability with the same shape as ``x``. The input
        is not modified.
    """
    agm_x = 1 - np.exp(-np.asarray(x))
    # np.where instead of masked item assignment: np.exp returns a NumPy
    # scalar for 0-d input, which does not support item assignment.
    agm_x = np.where(np.isnan(agm_x), 0, agm_x)
    return np.clip(agm_x, 1e-6, 1)
def choice(samples: List[int], weight: np.ndarray) -> int:
    """Pick one node from ``samples`` with probability proportional to ``weight``.

    A uniform draw in ``[0, sum(weight))`` is walked down the weights in
    order. The first sample whose weight exceeds the remaining amount wins.

    Args:
        samples (List[int]): neighbors
        weight (np.ndarray): weights aligned position-by-position with ``samples``

    Returns:
        int: node chosen. The last element of ``samples`` is returned if no
        weight exceeds the remainder (e.g. all-zero or NaN weights, or
        floating-point round-off at the end of the walk).
    """
    total = np.sum(weight)
    remaining = random.random() * total
    for candidate, candidate_weight in zip(samples, weight):
        if candidate_weight > remaining:
            return candidate
        remaining = remaining - candidate_weight
    # Fallback when the walk selected nothing.
    return samples[-1]
class CommunityGANSampling:
    """CommunityGAN Sampling

    Generator-side motif sampler. For each task, a motif is grown from a root
    node one node at a time. Each extra node comes from a random walk whose
    steps are weighted by ``agm`` relevance scores computed from the
    generator embeddings ``theta_g``. Tasks run in a ``multiprocessing.Pool``.

    Args:
        n_threads (int): number of worker processes in the pool.
        args (List[Tuple[int, int, bool]]): one ``(root, n_sample, only_neg)``
            tuple per task. ``run`` maps ``g_s`` over this iterable.
            root (int): root node id
            n_sample (int): num of motif sampled
            only_neg (bool): only return negative samples
        motif_size (int): motif size.
        total_motifs (List[List[Tuple]]): list of all motifs indexed by node id.
        theta_g (np.ndarray): node embedding of generator (one row per node).
        neighbor_set (Dict): neighbor set Dict indexed by node id.
    """

    def __init__(
        self,
        n_threads: int,
        args: List[Tuple[int, int, bool]],
        motif_size: int,
        total_motifs: List[List[Tuple]],
        theta_g: np.ndarray,
        neighbor_set: Dict,
    ) -> None:
        super().__init__()
        self.n_threads = n_threads
        self.args = args
        self.motif_size = motif_size
        self.total_motifs = total_motifs
        self.theta_g = theta_g
        self.neighbor_set = neighbor_set

    def g_v(self, roots: List[int]) -> Tuple[int, List[int]]:
        """get next node

        Runs a weighted random walk. The first step draws from the union of
        the neighbours of all ``roots``; later steps draw from the current
        node's neighbours. Each candidate is weighted by
        ``agm(sum(theta_g[candidate] * g))``. Here ``g`` is the element-wise
        product of the root embeddings, further multiplied by the current
        node's embedding after the first step. The walk stops as soon as it
        would step onto a node it has already left.

        Args:
            roots (List[int]): list of node sampled before (non-empty)

        Returns:
            Tuple[int, List[int]]: current_node, path walked. Returns
            ``(None, None)`` if the candidate neighbour set at any step is
            empty.
        """
        # Copy: basic integer indexing returns a *view* of theta_g's row, and
        # the in-place products below would otherwise overwrite the
        # generator embedding of roots[0].
        g_v_v = self.theta_g[roots[0]].copy()
        for nid in roots[1:]:
            g_v_v *= self.theta_g[nid]
        current_node = roots[-1]
        previous_nodes = set()
        path = []
        is_root = True
        while True:
            node_neighbor = (list({
                neighbor
                for root in roots
                for neighbor in self.neighbor_set[root]
            }) if is_root else list(self.neighbor_set[current_node]))
            if len(node_neighbor) == 0:
                return None, None
            tmp_g = g_v_v if is_root else g_v_v * self.theta_g[current_node]
            relevance_probability = agm(
                np.sum(self.theta_g[node_neighbor] * tmp_g, axis=1))
            next_node = choice(node_neighbor, relevance_probability)
            if next_node in previous_nodes:  # terminating condition
                break
            previous_nodes.add(current_node)
            current_node = next_node
            path.append(current_node)
            is_root = False
        return current_node, path

    def g_s(self, args: Tuple[int, int,
                              bool]) -> Tuple[List[Tuple], List[List[int]]]:
        """sampling for community gan generator

        Makes up to ``2 * n_sample`` attempts to grow a motif of
        ``motif_size`` distinct nodes from ``root``, using one ``g_v`` walk per
        additional node. An attempt is discarded if a walk dead-ends or if the
        motif repeats a node.

        Args:
            args (Tuple[int, int, bool]): root, n_sample, only_neg
                root (int): root node id
                n_sample (int): num of motif sampled
                only_neg (bool): only return negative samples, i.e. skip
                    motifs found in ``self.total_motifs``

        Returns:
            Tuple[List[Tuple], List[List[int]]]: motifs (sorted node-id
            tuples, at most ``n_sample`` of them) and, for each motif, the
            concatenated walk paths that produced it (starting with ``root``).
        """
        root, n_sample, only_neg = args
        _motifs = []
        _paths = []
        for _ in range(2 * n_sample):
            if len(_motifs) >= n_sample:
                break
            motif = [root]
            path = [root]
            for _ in range(1, self.motif_size):
                v, p = self.g_v(motif)
                if v is None:
                    break
                motif.append(v)
                path.extend(p)
            # Dead end (short motif) or a repeated node: discard this attempt.
            if len(set(motif)) < self.motif_size:
                continue
            motif = tuple(sorted(motif))
            # NOTE(review): this assumes total_motifs is a flat collection of
            # motif tuples. If it really is indexed by node id, as the class
            # docstring says, this compares a tuple with inner lists and is
            # always False; confirm whether total_motifs[root] was intended.
            if only_neg and motif in self.total_motifs:
                continue
            _motifs.append(motif)
            _paths.append(path)
        return _motifs, _paths

    def run(self) -> Tuple[List[Tuple], List[List[int]]]:
        """sampling for community gan

        Maps ``g_s`` over ``self.args`` with a pool of ``n_threads`` processes.
        ``self.g_s`` is a bound method, so the instance (embeddings, neighbour
        sets, motifs) is pickled to the workers.

        Returns:
            Tuple[List[Tuple], List[List[int]]]: motifs, paths. Each is a
            tuple with one entry per task in ``self.args`` (in task order):
            the motif list / path list that ``g_s`` returned for it. Both are
            empty tuples when ``self.args`` is empty.
        """
        tasks = list(self.args)
        if not tasks:
            # zip(*[]) yields nothing, so the unpacking below would raise.
            return (), ()
        with multiprocessing.Pool(self.n_threads) as p:
            motifs, paths = zip(*p.map(self.g_s, tasks))
        return motifs, paths
######################################################################################
# END: This section of code is adapted from https://github.com/SamJia/CommunityGAN #
######################################################################################