Source code for egc.utils.construct_DGLgraph

"""
construct_DGLgraph
"""
import collections

import dgl
import numpy as np
import torch
from sklearn.metrics import pairwise_distances as pair


def construct_DGLgraph_for_non_graph(x, labels, k=3, method="euclidean"):
    """Build a DGL graph from non-graph data by linking k-nearest neighbours.

    Args:
        x: Node features, either a tensor or anything ``torch.Tensor``
            accepts. One row per node.
        labels: Node labels, either a tensor or anything ``torch.Tensor``
            accepts.
        k: Number of nearest neighbours passed to ``dgl.knn_graph``.
        method: Distance name forwarded to ``dgl.knn_graph`` as ``dist``.
            The special value ``"heat"`` delegates to
            ``construct_DGLgraph_for_non_graph_by_heat`` instead.

    Returns:
        The graph produced by ``build_graph`` from the de-duplicated,
        symmetrised k-NN edges.
    """
    if method == "heat":
        return construct_DGLgraph_for_non_graph_by_heat(x, labels, k)

    if not torch.is_tensor(x):
        x = torch.Tensor(x)
    if not torch.is_tensor(labels):
        labels = torch.Tensor(labels)

    knn_g = dgl.knn_graph(x, k, dist=method)
    pairs = process_edges_info(knn_g.edges())

    # process_edges_info returns a list of (u, v) pairs, i.e. an (E, 2)
    # array, while build_graph reads edges[0] / edges[1] as the full source
    # and destination id arrays. Transpose to (2, E) so every edge is used
    # rather than just the first two pairs. The reshape keeps an empty edge
    # list well-formed as shape (2, 0).
    edges = np.array(pairs, dtype=np.int64).reshape(-1, 2).T
    return build_graph(x, edges, x.shape[0], labels)
def construct_DGLgraph_for_non_graph_by_heat(x, labels, k=3):
    """Build a DGL graph from non-graph data using heat-kernel similarity.

    Each node is linked to the nodes with the highest similarity
    ``exp(-0.5 * d**2)``, where ``d`` is the pairwise distance returned by
    ``sklearn.metrics.pairwise_distances``.

    Args:
        x: Node features, either a tensor or anything ``torch.Tensor``
            accepts. One row per node.
        labels: Node labels, either a tensor or anything ``torch.Tensor``
            accepts. They are stored on the graph but do not influence
            which edges are created.
        k: Number of neighbours to keep per node. ``k + 1`` candidates are
            selected so that the node itself, which normally has the
            maximal similarity of 1, can be discarded.

    Returns:
        The graph produced by ``build_graph`` from the de-duplicated,
        symmetrised neighbour edges.
    """
    if not torch.is_tensor(x):
        x = torch.Tensor(x)
    if not torch.is_tensor(labels):
        labels = torch.Tensor(labels)

    # Heat-kernel similarity matrix; the factor 0.5 presumably corresponds to
    # a time parameter t = 2 (the original author was unsure) - TODO confirm.
    similarity = np.exp(-0.5 * pair(x) ** 2)

    src = []
    dst = []
    for node, row in enumerate(similarity):
        # Indices of the k + 1 most similar nodes (unordered).
        neighbours = np.argpartition(row, -(k + 1))[-(k + 1):]
        for neighbour in neighbours:
            if neighbour != node:
                src.append(node)
                dst.append(int(neighbour))

    # process_edges_info expects edges[0] / edges[1] to be the complete
    # source / destination vectors, so the tensor must be shaped (2, E);
    # stacking along dim=1 produced (E, 2) and only two bogus pairs were
    # ever processed. Integer dtype keeps node ids integral.
    edges = torch.tensor([src, dst], dtype=torch.long)
    pairs = process_edges_info(edges)

    # build_graph likewise reads edges[0] / edges[1] as source / destination
    # arrays, so convert the (E, 2) pair list to (2, E). The reshape keeps an
    # empty edge list well-formed as shape (2, 0).
    edges = np.array(pairs, dtype=np.int64).reshape(-1, 2).T
    return build_graph(x, edges, x.shape[0], labels)
def construct_DGLgraph_for_graph(x, labels, edges):
    """Build a DGL graph from features, labels and an existing edge list.

    Args:
        x: Node features, either a tensor or anything ``torch.Tensor``
            accepts. One row per node.
        labels: Node labels, either a tensor or anything ``torch.Tensor``
            accepts.
        edges: Edge endpoints in the layout ``process_edges_info`` expects:
            ``edges[0]`` holds source ids and ``edges[1]`` destination ids,
            with elements exposing ``.item()``.

    Returns:
        The graph produced by ``build_graph`` from the de-duplicated,
        symmetrised edges.
    """
    if not torch.is_tensor(x):
        x = torch.Tensor(x)
    if not torch.is_tensor(labels):
        labels = torch.Tensor(labels)

    pairs = process_edges_info(edges)

    # process_edges_info returns a list of (u, v) pairs, i.e. an (E, 2)
    # array, while build_graph reads edges[0] / edges[1] as the full source
    # and destination id arrays. Transpose to (2, E) so every edge is used
    # rather than just the first two pairs. The reshape keeps an empty edge
    # list well-formed as shape (2, 0).
    edges = np.array(pairs, dtype=np.int64).reshape(-1, 2).T
    return build_graph(x, edges, x.shape[0], labels)
def process_edges_info(edges):
    """Return a de-duplicated, symmetrised list of edge pairs without loops.

    DGL counts repeated edges towards the total edge count, so duplicates
    and self-loops are dropped here; self-loops are added back uniformly
    later when the graph is built.

    Args:
        edges: Indexable where ``edges[0]`` holds source ids and
            ``edges[1]`` destination ids. Elements must expose ``.item()``.

    Returns:
        A list of ``(u, v)`` tuples. For every new edge both ``(u, v)`` and
        ``(v, u)`` are appended, in order of first appearance.
    """
    sources = edges[0]
    targets = edges[1]
    seen = set()
    result = []
    for idx in range(len(sources)):
        a = sources[idx].item()
        b = targets[idx].item()
        # Skip self-loops and pairs already recorded in either direction.
        if a == b or (a, b) in seen:
            continue
        seen.add((a, b))
        seen.add((b, a))
        result.append((a, b))
        result.append((b, a))
    return result
def build_graph(features, edges, num_nodes, labels):
    """Assemble a DGL graph with features, labels and self-loops.

    Shared by the graph constructors to avoid repeating this code.

    Args:
        features: Stored as the ``"feat"`` node data.
        edges: Indexable where ``edges[0]`` gives the source node ids and
            ``edges[1]`` the destination node ids.
        num_nodes: Number of nodes passed to ``dgl.graph``.
        labels: Stored as the ``"label"`` node data.

    Returns:
        The result of ``dgl.add_self_loop`` applied to the constructed graph.
    """
    src_ids, dst_ids = edges[0], edges[1]
    g = dgl.graph((src_ids, dst_ids), num_nodes=num_nodes)
    g.ndata["feat"] = features
    g.ndata["label"] = labels
    return dgl.add_self_loop(g)