Source code for egc.utils.load_data

"""
Load dataset with DGL for Graph Clustering
Author: Sheng Zhou
"""
import os
import ssl
import sys
from typing import Tuple

import dgl
import numpy as np
import scipy.io as sio
import torch
import wget
from ogb.nodeproppred import DglNodePropPredDataset
from torch.utils.data import Dataset

ogb_datasets = ["products", "arxiv", "mag", "proteins"]
dgl_datasets = ["Cora", "Citeseer", "Pubmed", "CoraFull", "Reddit"]
other_datasets = ["BlogCatalog", "Flickr", "ACM", "DBLP"]

DEFAULT_DATA_DIR = "./data"

BASE_CoLA_LAB = "https://github.com/GRAND-Lab/CoLA/blob/main"
BLOGCATALOG_URL = f"{BASE_CoLA_LAB}/raw_dataset/BlogCatalog/BlogCatalog.mat?raw=true"
FLICKR_URL = f"{BASE_CoLA_LAB}/raw_dataset/Flickr/Flickr.mat?raw=true"
SDCN_URL = "https://github.com/bdy9527/SDCN/blob/da6bb007b7/"


[docs]def load_data( dataset_name: str, directory=DEFAULT_DATA_DIR, ) -> Tuple[dgl.DGLGraph, torch.Tensor, int]: """Load datasets. Args: dataset_name (str): Name of the dataset. Check README.md for supported datasets. directory (str, optional): path for the dataset to save. Defaults to './data'. Raises: NotImplementedError: dataset not supported Returns: Tuple[dgl.DGLGraph, torch.Tensor, int]: graph, label, n_clusters """ if dataset_name in ogb_datasets: graph, label, n_clusters = load_ogb_data(dataset_name, directory) elif dataset_name in dgl_datasets: graph, label, n_clusters = load_dgl_data(dataset_name, directory) elif dataset_name in other_datasets: # pylint: disable=protected-access ssl._create_default_https_context = ssl._create_unverified_context # `python -m` may fail to download datasets in other_datasets_map graph, label, n_clusters = other_datasets_map[dataset_name](directory) else: raise NotImplementedError return graph, label, n_clusters
[docs]def load_ogb_data(dataset_name, directory=DEFAULT_DATA_DIR): """ graph:DGL graph ob+ject label: torch tensor of shape (num_nodes,num_tasks) """ dataset = DglNodePropPredDataset( name="ogbn-" + dataset_name, root=directory, ) graph, label = dataset[0] print(f"\n NumNodes: {graph.num_nodes()}\n" f" NumEdges: {graph.num_edges()}\n" f" NumFeats: {graph.ndata['feat'].shape[1]}\n" f" NumClasses: {dataset.num_classes}\n") return graph, label, dataset.num_classes
[docs]def load_dgl_data(dataset_name, directory=DEFAULT_DATA_DIR): """ graph:DGL graph object label: form graph.ndata['label'] """ print("\n") if dataset_name not in ("Reddit", "CoraFull"): dataset = getattr( dgl.data, dataset_name + "GraphDataset", )(raw_dir=directory) else: dataset = getattr( dgl.data, dataset_name + "Dataset", )(raw_dir=directory) print(f" NumNodes: {dataset[0].num_nodes()}\n" f" NumEdges: {dataset[0].num_edges()}\n" f" NumFeats: {dataset[0].ndata['feat'].shape[1]}\n" f" NumClasses: {dataset.num_classes}") print("\n") graph = dataset[0] label = graph.ndata["label"] return graph, label, dataset.num_classes
[docs]def allclose( a: torch.Tensor, b: torch.Tensor, rtol: float = 1e-4, atol: float = 1e-4, ) -> bool: """This function checks if a and b satisfy the condition: \|a - b\| <= atol + rtol * \|b\| Args: a (torch.Tensor): first tensor to compare b (torch.Tensor): second tensor to compare rtol (float, optional): relative tolerance. Defaults to 1e-4. atol (float, optional): absolute tolerance. Defaults to 1e-4. Returns: bool: True for close, False for not """ return torch.allclose( a.float().cpu(), b.float().cpu(), rtol=rtol, atol=atol, )
[docs]def is_bidirected(g: dgl.DGLGraph) -> bool: """Return whether the graph is a bidirected graph. A graph is bidirected if for any edge :math:`(u, v)` in :math:`G` with weight :math:`w`, there exists an edge :math:`(v, u)` in :math:`G` with the same weight. Args: g (dgl.DGLGraph): dgl.DGLGraph Returns: bool: True for bidirected, False for not """ src, dst = g.edges() num_nodes = g.num_nodes() # Sort first by src then dst idx_src_dst = src * num_nodes + dst perm_src_dst = torch.argsort(idx_src_dst, dim=0, descending=False) src1, dst1 = src[perm_src_dst], dst[perm_src_dst] # Sort first by dst then src idx_dst_src = dst * num_nodes + src perm_dst_src = torch.argsort(idx_dst_src, dim=0, descending=False) src2, dst2 = src[perm_dst_src], dst[perm_dst_src] return allclose(src1, dst2) and allclose(src2, dst1)
[docs]class AE_LoadDataset(Dataset): """AE_LoadDataset""" def __init__(self, data): super().__init__() self.x = data def __len__(self): return self.x.shape[0] def __getitem__(self, idx): return torch.from_numpy(np.array( self.x[idx])).float(), torch.from_numpy(np.array(idx))
[docs]def load_mat_data2dgl(data_path, verbose=True): """ load data from .mat file Args: data_path (str): the file to read in verbose (bool, optional): print info, by default True Returns: graph (DGL.graph): the graph read from data_path (torch.Tensor): label of node classes num_classes (int): number of node classes """ mat_path = data_path data_mat = sio.loadmat(mat_path) adj = data_mat["Network"] feat = data_mat["Attributes"] # feat = preprocessing.normalize(feat, axis=0) labels = data_mat["Label"] labels = labels.flatten() graph = dgl.from_scipy(adj) graph.ndata["feat"] = torch.from_numpy(feat.toarray()).to(torch.float32) graph.ndata["label"] = torch.from_numpy(labels).to(torch.int64) num_classes = len(np.unique(labels)) if verbose: print() print(" DGL dataset") print(f" NumNodes: {graph.number_of_nodes()}") print(f" NumEdges: {graph.number_of_edges()}") print(f" NumFeats: {graph.ndata['feat'].shape[1]}") print(f" NumClasses: {num_classes}") return graph, graph.ndata["label"], num_classes
[docs]def bar_progress(current, total, _): """create this bar_progress method which is invoked automatically from wget""" progress_message = f"Downloading: {current / total * 100}% [{current} / {total}] bytes" sys.stdout.write("\r" + progress_message) sys.stdout.flush()
[docs]def load_BlogCatalog(raw_dir=DEFAULT_DATA_DIR): """ load BlogCatalog dgl graph Args: raw_dir (str): Data path. Supports user customization. Returns: graph (DGL.graph): the graph read from data_path (torch.Tensor): label of node classes num_classes (int): number of node classes Examples: >>> graph, label, n_clusters = load_BlogCatalog() """ data_file = os.path.join(raw_dir, "BlogCatalog.mat") if not os.path.exists(data_file): url = BLOGCATALOG_URL wget.download(url, out=data_file, bar=bar_progress) return load_mat_data2dgl(data_path=data_file)
[docs]def load_Flickr(raw_dir=DEFAULT_DATA_DIR): """ load Flickr dgl graph Args: raw_dir (str): Data path. Supports user customization. Returns: graph (DGL.graph): the graph read from data_path (torch.Tensor): label of node classes num_classes (int): number of node classes Examples: >>> graph, label, n_clusters = load_Flickr() """ data_file = os.path.join(raw_dir, "Flickr.mat") if not os.path.exists(data_file): url = FLICKR_URL wget.download(url, out=data_file, bar=bar_progress) return load_mat_data2dgl(data_path=data_file)
[docs]def load_ACM(raw_dir=DEFAULT_DATA_DIR, verbose=True): """ load ACM dgl graph Args: raw_dir (str): Data path. Supports user customization. verbose (bool, optional): print info, by default True Returns: graph (DGL.graph): the graph read from data_path (torch.Tensor): label of node classes num_classes (int): number of node classes Examples: >>> graph, label, n_clusters = load_ACM() """ if not os.path.exists(os.path.join(raw_dir, "acm")): os.mkdir(os.path.join(raw_dir, "acm")) adj_data_file = os.path.join(raw_dir, "acm", "adj.txt") if not os.path.exists(adj_data_file): url = SDCN_URL + "graph/acm_graph.txt?raw=true" wget.download(url, out=adj_data_file, bar=bar_progress) feat_data_file = os.path.join(raw_dir, "acm", "feat.txt") if not os.path.exists(feat_data_file): url = SDCN_URL + "data/acm.txt?raw=true" wget.download(url, out=feat_data_file, bar=bar_progress) label_data_file = os.path.join(raw_dir, "acm", "label.txt") if not os.path.exists(label_data_file): url = SDCN_URL + "data/acm_label.txt?raw=true" wget.download(url, out=label_data_file, bar=bar_progress) edges_unordered = np.genfromtxt(adj_data_file, dtype=np.int32) graph = dgl.graph((edges_unordered[:, 0], edges_unordered[:, 1])) feat = np.loadtxt(feat_data_file, dtype=float) labels = np.loadtxt(label_data_file, dtype=int) graph.ndata["feat"] = torch.from_numpy(feat).to(torch.float32) graph.ndata["label"] = torch.from_numpy(labels).to(torch.int64) num_classes = len(np.unique(labels)) if verbose: print() print(" DGL dataset") print(f" NumNodes: {graph.number_of_nodes()}") print(f" NumEdges: {graph.number_of_edges()}") print(f" NumFeats: {graph.ndata['feat'].shape[1]}") print(f" NumClasses: {num_classes}") return graph, graph.ndata["label"], num_classes
[docs]def load_DBLP(raw_dir=DEFAULT_DATA_DIR, verbose=True): """ load DBLP dgl graph Args: raw_dir (str): Data path. Supports user customization. verbose (bool, optional): print info, by default True Returns: graph (DGL.graph): the graph read from data_path (torch.Tensor): label of node classes num_classes (int): number of node classes Examples: >>> graph, label, n_clusters = load_DBLP() """ if not os.path.exists(os.path.join(raw_dir, "dblp")): os.mkdir(os.path.join(raw_dir, "dblp")) adj_data_file = os.path.join(raw_dir, "dblp", "adj.txt") if not os.path.exists(adj_data_file): url = SDCN_URL + "graph/dblp_graph.txt?raw=true" wget.download(url, out=adj_data_file, bar=bar_progress) feat_data_file = os.path.join(raw_dir, "dblp", "feat.txt") if not os.path.exists(feat_data_file): url = SDCN_URL + "data/dblp.txt?raw=true" wget.download(url, out=feat_data_file, bar=bar_progress) label_data_file = os.path.join(raw_dir, "dblp", "label.txt") if not os.path.exists(label_data_file): url = SDCN_URL + "data/dblp_label.txt?raw=true" wget.download(url, out=label_data_file, bar=bar_progress) edges_unordered = np.genfromtxt(adj_data_file, dtype=np.int32) graph = dgl.graph((edges_unordered[:, 0], edges_unordered[:, 1])) print(graph) feat = np.loadtxt(feat_data_file, dtype=float) labels = np.loadtxt(label_data_file, dtype=int) graph.ndata["feat"] = torch.from_numpy(feat).to(torch.float32) graph.ndata["label"] = torch.from_numpy(labels).to(torch.int64) num_classes = len(np.unique(labels)) if verbose: print() print(" DGL dataset") print(f" NumNodes: {graph.number_of_nodes()}") print(f" NumEdges: {graph.number_of_edges()}") print(f" NumFeats: {graph.ndata['feat'].shape[1]}") print(f" NumClasses: {num_classes}") return graph, graph.ndata["label"], num_classes
other_datasets_map = { "ACM": load_ACM, "Flickr": load_Flickr, "BlogCatalog": load_BlogCatalog, "DBLP": load_DBLP, } # # for test # if __name__ == "__main__": # import pandas as pd # import scipy.sparse as sp # dataset_list = [ # 'Cora', 'Citeseer', 'Pubmed', 'BlogCatalog', 'Flickr', 'ACM' # ] # n_nodes_list, n_edges_list, n_attr_list, n_class_list = [], [], [], [] # for data_name in dataset_list: # print('-' * 10, data_name, '-' * 10) # graph, label, n_clusters = load_data(data_name) # print(is_bidirected(graph)) # print(label) # n_nodes_list.append(graph.number_of_nodes()) # n_edges_list.append(graph.number_of_edges()) # n_attr_list.append(graph.ndata['feat'].shape[1]) # n_class_list.append(n_clusters) # df_dataset = pd.DataFrame({ # 'Datasets': dataset_list, # 'Nodes': n_nodes_list, # 'Edges': n_edges_list, # 'Attributes': n_attr_list, # 'Classes': n_class_list # }) # print(df_dataset)