Source code for egc.module.layers.inner_product_de
"""
layers
"""
import torch
import torch.nn.functional as F
from torch import nn
class InnerProductDecoder(nn.Module):
    """Decoder for using inner product for prediction.

    Produces a dense score matrix from embeddings as ``act(z @ z^T)``, with
    dropout applied to ``z`` while the module is in training mode.

    Args:
        dropout: Dropout probability applied to ``z`` before the inner
            product. Only active when ``self.training`` is True.
        act: Callable applied to the inner-product scores. Defaults to
            :func:`torch.sigmoid`. Pass ``None`` to return the raw scores
            (logits) unchanged.
    """

    def __init__(self, dropout: float = 0.0, act=torch.sigmoid):
        super().__init__()
        self.dropout = dropout
        self.act = act

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Compute pairwise inner-product scores.

        Args:
            z: Embedding tensor of shape ``(N, D)``, or ``(..., N, D)`` for a
                batch of embedding matrices.

        Returns:
            Tensor of shape ``(N, N)`` (or ``(..., N, N)`` for batched input)
            containing ``act(z @ z^T)``, or the raw scores if ``act`` is None.
        """
        z = F.dropout(z, self.dropout, training=self.training)
        # For 2-D input this is exactly torch.mm(z, z.t()); transposing the
        # last two dims and using matmul additionally supports leading batch
        # dimensions.
        scores = torch.matmul(z, z.transpose(-2, -1))
        if self.act is None:
            return scores
        return self.act(scores)