# Source code for egc.module.layers.selfexpr_secomm

"""
Self-Expressive module for SEComm
"""
import torch
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class SECommSelfExpr(Module):
    """Self-Expressive module for SEComm.

    Holds a learnable ``n x n`` coefficient matrix and expresses each of the
    ``n`` input rows as a combination of the *other* rows (the diagonal of the
    coefficient matrix is excluded in :meth:`forward`).

    Args:
        n: number of rows (samples) the coefficient matrix relates.
    """

    def __init__(self, n: int):
        super().__init__()
        self.n = n
        # Small non-negative initial coefficients in [0, 0.01).
        self.weight = Parameter(
            torch.empty(n, n, dtype=torch.float32).uniform_(0, 0.01))

    def forward(self, x: torch.Tensor):
        """Reconstruct ``x`` from itself via the coefficient matrix.

        Args:
            x: 2-D tensor with ``n`` rows (``torch.mm`` requires its row count
                to match the ``n x n`` weight).

        Returns:
            A tuple ``(weight, output)``: the raw coefficient parameter
            (diagonal included) and ``(weight - diag(weight)) @ x``.
        """
        # Remove the diagonal so no row can trivially reconstruct itself.
        off_diagonal = self.weight - torch.diag(torch.diagonal(self.weight))
        output = torch.mm(off_diagonal, x)
        return self.weight, output

    def reset(self):
        """Re-draw the coefficients uniformly from ``[0, 0.01)``.

        The re-initialisation happens in place, so the parameter keeps its
        current device, dtype and identity. This means optimizers that
        reference it keep working, and a module moved with ``.to()`` or
        ``.cuda()`` is not silently moved back to the CPU.
        """
        with torch.no_grad():
            self.weight.uniform_(0, 0.01)