Source code for egc.utils.graph_augmentation.transforms.node_shuffle

"""Randomly shuffle"""
from dgl import backend as F
from dgl import BaseTransform


class NodeShuffle(BaseTransform):
    r"""Randomly shuffle node features within each node type.

    For every node type of the input graph one random permutation of that
    type's nodes is drawn. That single permutation is applied to *every*
    feature stored on that node type, so rows of different features stay
    aligned with each other while being reassigned to randomly chosen nodes
    of the same type. Edges and edge features are not touched.

    Parameters
    ----------
    is_use : bool, optional
        If ``False`` the transform is a no-op and the input graph is returned
        as-is. Default: ``True``.
    inplace : bool, optional
        Keyword-only. If ``True`` (the default, matching the historical
        behaviour of this class) the node features of the input graph are
        overwritten and the same graph object is returned. If ``False`` the
        graph is cloned first, so the caller's graph keeps its original
        features (useful when several augmented views of one graph are needed).

    Example
    -------
    The printed values below are one possible outcome; the permutation is
    random, but ``h1`` and ``h2`` are always permuted identically.

    >>> import dgl
    >>> import torch
    >>> from egc.utils.graph_augmentation.transforms.node_shuffle import NodeShuffle
    >>> transform = NodeShuffle()
    >>> g = dgl.graph(([0, 1], [1, 2]))
    >>> g.ndata['h1'] = torch.tensor([[1., 2.], [3., 4.], [5., 6.]])
    >>> g.ndata['h2'] = torch.tensor([[7., 8.], [9., 10.], [11., 12.]])
    >>> g = transform(g)
    >>> print(g.ndata['h1'])  # doctest: +SKIP
    tensor([[5., 6.],
            [3., 4.],
            [1., 2.]])
    >>> print(g.ndata['h2'])  # doctest: +SKIP
    tensor([[11., 12.],
            [ 9., 10.],
            [ 7.,  8.]])
    """

    def __init__(self, is_use=True, *, inplace=True):
        self.is_use = is_use
        self.inplace = inplace

    def __repr__(self):
        return f"{self.__class__.__name__}(is_use={self.is_use}, inplace={self.inplace})"

    def __call__(self, g):
        """Return ``g`` with node features shuffled per node type.

        Parameters
        ----------
        g : DGLGraph
            Input graph. Modified in place unless ``inplace=False``.

        Returns
        -------
        DGLGraph
            The (possibly cloned) graph with permuted node features, or the
            untouched input graph when ``is_use`` is ``False``.
        """
        if not self.is_use:
            return g
        if not self.inplace:
            # Work on a copy so the caller's graph keeps its original features.
            g = g.clone()
        for ntype in g.ntypes:
            # Relies on DGL node IDs of a type being 0..N-1, so shuffling
            # them yields a permutation usable directly as a row index.
            nids = F.astype(g.nodes(ntype), F.int64)
            perm = F.rand_shuffle(nids)
            node_data = g.nodes[ntype].data
            # Snapshot the items: entries of this same mapping are
            # overwritten inside the loop, so don't iterate the live view.
            for key, feat in list(node_data.items()):
                node_data[key] = feat[perm]
        return g