Source code for egc.module.layers.multilayer_gnn

"""MultiLayer GraphSAGE"""
from typing import List, Optional

from dgl.nn.pytorch.conv import SAGEConv
from torch import nn

from ...utils import act_map


class MultiLayerGNN(nn.Module):
    """MultiLayer GraphSAGE with different types of aggregator_type.

    One ``SAGEConv`` layer is created per entry of ``out_feats_list`` and
    registered on the module as the attribute ``conv{idx}``. Each layer's
    input size is the previous layer's output size, starting from
    ``in_feats``.

    Args:
        in_feats (int): Input feature dimension.
        out_feats_list (List[int]): List of hidden units dimensions.
        aggregator_type (str, optional): Aggregate type of sage.
            Defaults to 'gcn'.
        bias (bool, optional): Whether to apply bias. Defaults to True.
        activation (List[str], optional): Per-layer activation keys; entry
            ``idx`` is looked up in ``act_map`` and handed to layer ``idx``.
            When None, every layer gets a fresh ``nn.ReLU()``.
            Defaults to None.
        dropout (float, optional): Value passed to every layer as
            ``feat_drop``. Defaults to 0.0.

    Raises:
        IndexError: If ``activation`` is given but has fewer entries than
            ``out_feats_list``.
    """

    def __init__(
        self,
        in_feats: int,
        out_feats_list: List[int],
        aggregator_type: str = "gcn",
        bias: bool = True,
        activation: Optional[List[str]] = None,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        # Fail before any layer is built, with a message that names the
        # mismatch. IndexError is kept because that is what indexing a
        # too-short list raised here previously.
        if activation is not None and len(activation) < len(out_feats_list):
            raise IndexError(
                f"activation has {len(activation)} entries but "
                f"out_feats_list defines {len(out_feats_list)} layers"
            )

        pre_layer = in_feats
        for idx, hidden_units in enumerate(out_feats_list):
            # NOTE(review): act_map is defined outside this file; presumably it
            # maps activation names to callables/modules - confirm.
            layer_activation = (
                nn.ReLU() if activation is None else act_map[activation[idx]]
            )
            # Layers stay registered as conv0, conv1, ... (rather than in a
            # container) so the parameter names in state_dict do not change.
            setattr(
                self,
                f"conv{idx}",
                SAGEConv(
                    in_feats=pre_layer,
                    out_feats=hidden_units,
                    aggregator_type=aggregator_type,
                    bias=bias,
                    feat_drop=dropout,
                    activation=layer_activation,
                ),
            )
            # The next layer consumes this layer's output.
            pre_layer = hidden_units

    def forward(self, blocks, x, edge_weight=None):
        """Apply the layers in order, one per element of ``blocks``.

        Args:
            blocks: Iterable of graphs; element ``idx`` is passed to layer
                ``conv{idx}``. Presumably DGL blocks / message-flow graphs -
                TODO confirm against callers.
            x: Input features for the first layer.
            edge_weight (optional): Passed unchanged to every layer.
                Defaults to None.

        Returns:
            The output of the last layer applied, or ``x`` unchanged when
            ``blocks`` is empty.
        """
        for idx, block in enumerate(blocks):
            # Only as many layers as there are blocks are used; more blocks
            # than layers fails on the attribute lookup below.
            conv = getattr(self, f"conv{idx}")
            x = conv(block, x, edge_weight)
        return x