# Source code for egc.utils.graph_augmentation.transforms.feature_dropout

"""Feature Dropout"""
import torch.nn.functional as F
from dgl import BaseTransform


class FeatureDropout(BaseTransform):
    r"""Augment a graph by applying dropout to its node features.

    For every node type, the ``"feat"`` node feature is passed through
    :func:`torch.nn.functional.dropout`. Each entry is set to 0 with
    probability ``p``; as the example below shows, the entries that are
    kept are scaled by ``1 / (1 - p)``.

    Parameters
    ----------
    p : float, optional
        Probability of a node feature entry being masked. Must lie in
        ``[0, 1]``. Defaults to 0.5.

    Raises
    ------
    ValueError
        If ``p`` is outside ``[0, 1]``.

    Example
    -------
    >>> import dgl
    >>> import torch
    >>> from dglAug import FeatureDropout

    >>> transform = FeatureDropout(p=0.2)
    >>> g = dgl.rand_graph(4,2)
    >>> g.ndata['feat'] = torch.rand((4,5))
    >>> print(g.ndata['feat'])
    tensor([[0.7706, 0.3505, 0.1246, 0.5076, 0.3071],
            [0.5388, 0.6082, 0.5088, 0.8058, 0.4955],
            [0.7638, 0.3115, 0.4265, 0.5507, 0.4404],
            [0.3127, 0.0056, 0.1876, 0.9971, 0.6389]])
    >>> g = transform(g)
    >>> print(g.ndata['feat'])
    tensor([[0.0000, 0.0000, 0.1558, 0.6345, 0.3839],
            [0.0000, 0.7603, 0.6360, 0.0000, 0.6194],
            [0.0000, 0.3893, 0.5331, 0.0000, 0.5505],
            [0.3909, 0.0070, 0.2345, 1.2464, 0.0000]])
    """

    def __init__(self, p=0.5):
        # Fail fast: without this check an invalid probability would only
        # surface later, from inside F.dropout, on the first call.
        if not 0 <= p <= 1:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {p}"
            )
        self.p = p

    def __call__(self, g):
        """Apply feature dropout to ``g`` and return it.

        Parameters
        ----------
        g
            Graph to augment (presumably a ``dgl.DGLGraph`` -- the type is
            not visible in this file). Every node type in ``g.ntypes`` is
            expected to carry a ``"feat"`` node feature.

        Returns
        -------
        The same graph object ``g`` that was passed in; the new feature
        values are handed to ``g.apply_nodes`` for each node type.
        """
        # Nothing would be dropped, so skip the per-type pass entirely.
        if self.p == 0:
            return g

        def _drop_feat(nodes):
            # self.p is read at call time, matching the original lambda.
            return {"feat": F.dropout(nodes.data["feat"], self.p)}

        for ntype in g.ntypes:
            g.apply_nodes(_drop_feat, ntype=ntype)
        return g
# if __name__ == '__main__': # transform = FeatureDropout(p=0.2) # import dgl # g = dgl.rand_graph(4, 2) # g.ndata['feat'] = torch.rand((4, 5)) # print(g.ndata['feat']) # g = transform(g) # print(g.ndata['feat'])