# Source code for egc.utils.metrics

"""Metrics
"""
import torch

######################################################################################
# START: This section of code is adapted from https://github.com/bwilder0/clusternet #
######################################################################################


def get_soft_assignment_matrix(
    data: torch.Tensor,
    miu: torch.Tensor,
    cluster_temp: float = 30,
    dist_type: str = "cosine_similarity",
) -> torch.Tensor:
    """Get soft assignment matrix from data points and cluster centers.

    Each row is a softmax, over the ``k`` cluster centers, of the similarity
    between that data point and every center scaled by ``cluster_temp``.

    Args:
        data (torch.Tensor): n x d data embeddings.
        miu (torch.Tensor): k x d cluster center embeddings.
        cluster_temp (float, optional): softmax temperature; similarities are
            multiplied by it before the softmax. Defaults to 30.
        dist_type (str, optional): similarity measure, either
            ``'cosine_similarity'`` or ``'dot'``. Defaults to
            'cosine_similarity'.

    Raises:
        ValueError: if ``dist_type`` is not one of the supported values.

    Returns:
        torch.Tensor: n x k soft assignment matrix whose rows sum to 1.
    """
    n = data.shape[0]
    d = data.shape[1]
    k = miu.shape[0]
    if dist_type == "cosine_similarity":
        # Pair every data point with every center (n * k rows of length d),
        # take the row-wise cosine similarity, then fold back to n x k.
        dist = torch.cosine_similarity(
            data[:, None].expand(n, k, d).reshape((-1, d)),
            miu[None].expand(n, k, d).reshape((-1, d)),
        ).reshape((n, k))
    elif dist_type == "dot":
        dist = data @ miu.t()
    else:
        # Reject unknown measures explicitly rather than leaving ``dist``
        # unbound (which would surface as an obscure UnboundLocalError).
        raise ValueError(
            f"Unsupported dist_type {dist_type!r}; "
            "expected 'cosine_similarity' or 'dot'."
        )
    soft_assignment_matrix = torch.softmax(cluster_temp * dist, 1)
    return soft_assignment_matrix
def get_modularity_matrix(adj_nodia: torch.Tensor) -> torch.Tensor:
    r"""Get the modularity matrix of a graph.

    .. math::
        B_{vw} = A_{vw} - \frac{k_v k_w}{2m}

    Here :math:`k_v` is the column sum of the adjacency matrix for node
    :math:`v`, and :math:`2m` is the sum of all adjacency entries.

    Args:
        adj_nodia (torch.Tensor): adjacency matrix without diag.

    Returns:
        torch.Tensor: modularity matrix, same shape as ``adj_nodia``.
    """
    deg = adj_nodia.sum(dim=0)
    total_weight = adj_nodia.sum()
    # Outer product of the degree vector: entry (v, w) is k_v * k_w.
    expected_edges = deg[:, None] * deg[None, :]
    return adj_nodia - expected_edges / total_weight
def get_modularity_value(
    bin_adj_nodiag: torch.Tensor,
    r: torch.Tensor,
    mod: torch.Tensor,
) -> torch.Tensor:
    r"""Get the (soft) modularity of a cluster assignment.

    .. math::
        Q(r) = \frac{1}{2m}\sum_{u,v\in V}\sum_{k=1}^K
        \left[A_{uv} - \frac{d_u d_v}{2m}\right] r_{uk} r_{vk}

    Args:
        bin_adj_nodiag (torch.Tensor): n x n binary adjacency matrix without
            diag; the sum of its entries is used as :math:`2m`.
        r (torch.Tensor): n x k soft assignment probability matrix.
        mod (torch.Tensor): n x n modularity matrix.

    Returns:
        torch.Tensor: modularity value (scalar tensor).
    """
    two_m = bin_adj_nodiag.sum()
    # (k x n) @ (n x n) @ (n x k): diagonal entry c accumulates mod[u, v]
    # weighted by r[u, c] * r[v, c]; the trace sums over all clusters.
    cluster_scores = (r.t() @ mod) @ r
    return (1.0 / two_m) * torch.trace(cluster_scores)
######################################################################################
# END: This section of code is adapted from https://github.com/bwilder0/clusternet   #
######################################################################################