Source code for egc.utils.data_loader.naive_data_loader

"""Naive Graph Batch Data Loader
"""
from typing import List
from typing import Union

import dgl
import torch

from ..graph_augmentation import ComposeAug
from ..graph_augmentation import get_augments


class NaiveDataLoader:
    """Naive DataLoader using a full neighbor sampler.

    Args:
        graph (dgl.DGLGraph): graph.
        batch_size (int, optional): batch size. Defaults to 1024.
        n_layers (int, optional): number of GNN layers. Defaults to 1.
        fanouts (List[int] or int, optional): number of neighbors to sample for
            each GNN layer, with the i-th element being the fanout for the i-th
            GNN layer. Defaults to -1.

            - If only a single integer is provided, every layer uses the same fanout.
            - If -1 is provided for a layer, all inbound edges are included.
        aug_types (List, optional): augmentation types list. Defaults to ["none"].
        device (torch.device, optional): device onto which each batch is loaded.
            Defaults to torch.device("cpu").
        drop_last (bool, optional): set to True to drop the last incomplete batch.
            Defaults to False.

    Raises:
        ValueError: raised if any augmentation type is not supported.
    """

    def __init__(
        self,
        graph: dgl.DGLGraph,
        batch_size: int = 1024,
        n_layers: int = 1,
        fanouts: Union[List[int], int] = -1,
        aug_types: List = None,
        device: torch.device = torch.device("cpu"),
        drop_last: bool = False,
    ) -> None:
        aug_types = ["none"] if aug_types is None else aug_types
        # get_augments raises ValueError on unsupported augmentation types
        aug_list = get_augments(aug_types)
        transform = ComposeAug(aug_list, cross=False)
        self.aug_graphs = transform(graph)
        sampler = dgl.dataloading.NeighborSampler(
            fanouts=[fanouts] * n_layers if isinstance(fanouts, int) else fanouts)
        self.data_loaders = [
            dgl.dataloading.DataLoader(
                g,
                g.nodes(),
                sampler,
                batch_size=batch_size,
                shuffle=False,
                drop_last=drop_last,
                # GPU-resident loading requires single-process sampling
                num_workers=0 if torch.cuda.is_available() else 12,
                device=device,
            ) for g in self.aug_graphs
        ]
        self.device = device
    def get_aug_graphs(self) -> List[dgl.DGLGraph]:
        """Get the augmented graphs.

        Returns:
            List[dgl.DGLGraph]: list of augmented graphs.
        """
        return self.aug_graphs
    def __iter__(self):
        """Return the iterator of the data loader."""
        return _DataLoaderIter(self)

    def __len__(self):
        """Return the number of batches of the data loader."""
        return len(self.data_loaders[0])
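
# Usage sketch (illustrative, not part of the module): the ``load_data`` helper
# and dataset name are hypothetical stand-ins, mirroring the commented test at
# the bottom of this file; the per-view iteration assumes ``ComposeAug``
# returns one augmented graph per entry in ``aug_types``.
#
#     graph, _, _ = load_data("Cora")  # any dgl.DGLGraph with node features
#     loader = NaiveDataLoader(
#         graph,
#         batch_size=512,
#         n_layers=2,
#         fanouts=[10, 25],  # layer-wise fanouts; -1 would keep all in-edges
#     )
#     for batches in loader:  # len(loader) == number of batches per epoch
#         for input_nodes, output_nodes, blocks in batches:  # one per view
#             pass  # forward the sampled blocks through a GNN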
class _DataLoaderIter:

    def __init__(self, data_loaders):
        self.device = data_loaders.device
        self.data_loaders = data_loaders.data_loaders
        self.iters_ = [iter(data_loader) for data_loader in self.data_loaders]

    # Make this an iterator for PyTorch Lightning compatibility
    def __iter__(self):
        return self

    def __next__(self):
        # [(input_nodes, output_nodes, blocks), ...], one tuple per augmented
        # view. Each batch is materialized as a tuple; a bare generator
        # expression here would yield unindexable, lazily evaluated batches.
        return [
            tuple(_to_device(data, self.device) for data in next(iter_))
            for iter_ in self.iters_
        ]


def _to_device(data, device):
    if isinstance(data, dict):
        for k, v in data.items():
            data[k] = v.to(device)
    elif isinstance(data, list):
        data = [item.to(device) for item in data]
    else:
        data = data.to(device)
    return data


# def iter_data_loader(data_loaders: List) -> Generator:
#     iter_dls = [iter(dl) for dl in data_loaders]
#     for _ in range(len(data_loaders[0])):
#         try:
#             yield [next(dl) for dl in iter_dls]
#         except StopIteration:
#             return

# for test only
# if __name__ == '__main__':
#     from utils import load_data
#     from utils import augment_graph
#     import torch
#     graph, _, _ = load_data('Cora')
#     from module.data_loader.NaiveDataLoader import NaiveDataLoader
#     dl = NaiveDataLoader(graph)
#     for i, j in dl:
#         print(torch.all(i[2][0].dstdata['feat'][0] == j[2][0].dstdata['feat'][0]))
#     for i, j in dl:
#         print(i[2][0].srcdata['feat'].shape, j[2][0].srcdata['feat'].shape)
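
# Reference sketch of ``_to_device`` dispatch (illustrative): a batch element
# may be a node-ID ``torch.Tensor``, a dict of tensors, or the list of
# ``DGLBlock``s produced by the sampler; each shape is moved one level deep
# (dict values and list items are not recursed further):
#
#     cpu = torch.device("cpu")
#     _to_device(torch.arange(4), cpu)                  # plain tensor
#     _to_device({"feat": torch.rand(4, 3)}, cpu)       # dict values moved
#     _to_device([torch.zeros(2), torch.ones(2)], cpu)  # list items moved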