# Source code for egc.module.loss.contrastive_loss

"""Contrastive Loss
Adapted from https://github.com/Yunfan-Li/Contrastive-Clustering
"""
import math

import torch
from torch import nn


def mask_correlated_samples(batch_size):
    """Build the negative-sample mask for two stacked batches.

    The ``2 * batch_size`` samples are laid out as ``[view_1; view_2]``, so
    sample ``k`` and sample ``batch_size + k`` form a positive pair. The result
    is a boolean ``(2 * batch_size, 2 * batch_size)`` tensor. It is ``False``
    on the main diagonal and on both positive-pair positions, and ``True``
    everywhere else (the negatives).

    Args:
        batch_size (int): number of samples in one view.

    Returns:
        torch.Tensor: boolean mask of shape ``(2 * batch_size, 2 * batch_size)``.
    """
    total = 2 * batch_size
    keep = torch.ones((total, total)).fill_diagonal_(0)
    first = torch.arange(batch_size)
    second = first + batch_size
    # Clear both symmetric positive-pair entries in one vectorised step each.
    keep[first, second] = 0
    keep[second, first] = 0
    return keep.bool()


class InstanceLoss(nn.Module):
    """Instance Contrastive Loss.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form the positive pair. Every
    other row of the concatenated ``[z_i; z_j]`` batch, except the sample
    itself, is treated as a negative. The loss is a summed cross-entropy over
    those logits, divided by the number of samples.

    Args:
        batch_size (int): expected number of rows in each of ``z_i`` / ``z_j``.
            Used to pre-build the negative-sample mask.
        temperature (float): divisor applied to the similarity logits.
    """

    def __init__(self, batch_size, temperature):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        # Boolean (2B, 2B) mask: True exactly on the negative pairs.
        self.mask = mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def forward(self, z_i, z_j):
        """Compute the instance-level contrastive loss.

        Args:
            z_i (torch.Tensor): 2-D embeddings of the first view, ``(B, D)``.
            z_j (torch.Tensor): 2-D embeddings of the second view; assumed to
                have the same shape as ``z_i``.

        Returns:
            torch.Tensor: scalar loss (summed cross-entropy divided by ``2B``).
        """
        # Use the real batch size so a batch that differs from the configured
        # size (e.g. a final, smaller batch) still pairs rows correctly.
        batch_size = z_i.size(0)
        N = 2 * batch_size
        if batch_size == self.batch_size:
            mask = self.mask
        else:
            mask = mask_correlated_samples(batch_size)

        z = torch.cat((z_i, z_j), dim=0)
        # Raw dot products. NOTE(review): inputs are presumably L2-normalised
        # by the caller so this is a cosine similarity — confirm.
        sim = torch.matmul(z, z.T) / self.temperature
        # Positive pairs sit on the +B / -B off-diagonals.
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        # The mask is built on CPU; move it to wherever the logits live.
        negative_samples = sim[mask.to(sim.device)].reshape(N, -1)
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        # The positive logit is always column 0. Create labels on the logits'
        # device instead of hard-coding CUDA.
        labels = torch.zeros(N, dtype=torch.long, device=logits.device)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss


class ClusterLoss(nn.Module):
    """Cluster Contrastive Loss.

    Each column of the assignment matrices (one per cluster) is treated as a
    sample. Column ``k`` of ``c_i`` and column ``k`` of ``c_j`` form the
    positive pair; every other column is a negative. A negative-entropy term
    on each view's normalised cluster sizes is added to the result.

    Args:
        class_num (int): number of clusters; expected to equal the number of
            columns of ``c_i`` and ``c_j``.
        temperature (float): divisor applied to the cosine similarities.
    """

    def __init__(self, class_num, temperature):
        super().__init__()
        self.class_num = class_num
        self.temperature = temperature
        # Boolean (2K, 2K) mask: True exactly on the negative cluster pairs.
        self.mask = mask_correlated_samples(class_num)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        # Cosine similarity over the last axis of a broadcast
        # (2K, 1, B) x (1, 2K, B) pair -> (2K, 2K).
        self.similarity_f = nn.CosineSimilarity(dim=2)

    @staticmethod
    def _negative_entropy(c):
        """Return ``log(K) + sum(p * log p)`` for ``p`` = normalised column sums of ``c``.

        This equals ``log(K) - H(p)``. It is zero when all clusters receive
        equal mass and positive otherwise.
        """
        p = c.sum(0).view(-1)
        p = p / p.sum()
        # NOTE(review): a cluster with exactly zero total mass gives
        # 0 * log(0) = NaN here, same as the original implementation.
        return math.log(p.size(0)) + (p * torch.log(p)).sum()

    def forward(self, c_i, c_j):
        """Compute the cluster contrastive loss plus the balance term.

        Args:
            c_i (torch.Tensor): 2-D assignment matrix of the first view,
                ``(batch, class_num)`` — presumably per-sample cluster
                probabilities; confirm with caller.
            c_j (torch.Tensor): same for the second view.

        Returns:
            torch.Tensor: scalar loss (contrastive term + both entropy terms).
        """
        ne_loss = self._negative_entropy(c_i) + self._negative_entropy(c_j)

        N = 2 * self.class_num
        # One row per cluster: transpose each view and stack to (2K, batch).
        c = torch.cat((c_i.t(), c_j.t()), dim=0)
        sim = self.similarity_f(c.unsqueeze(1), c.unsqueeze(0)) / self.temperature
        # Positive pairs sit on the +K / -K off-diagonals.
        sim_i_j = torch.diag(sim, self.class_num)
        sim_j_i = torch.diag(sim, -self.class_num)
        positive_clusters = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        # The mask is built on CPU; move it to wherever the logits live.
        negative_clusters = sim[self.mask.to(sim.device)].reshape(N, -1)
        logits = torch.cat((positive_clusters, negative_clusters), dim=1)
        # The positive logit is always column 0. Create labels on the logits'
        # device instead of hard-coding CUDA.
        labels = torch.zeros(N, dtype=torch.long, device=logits.device)
        loss = self.criterion(logits, labels) / N
        return loss + ne_loss