# Source code for torch_geometric_signed_directed.utils.directed.prob_imbalance_loss

from typing import Optional, Union

import torch
import numpy as np


class Prob_Imbalance_Loss(torch.nn.Module):
    r"""An implementation of the probabilistic imbalance loss function from the
    `DIGRAC: Digraph Clustering Based on Flow Imbalance
    <https://proceedings.mlr.press/v198/he22b.html>`_ paper.

    Args:
        F (int or NumPy array, optional) - Number of pairwise imbalance scores
            to consider, or the meta-graph adjacency matrix. When a matrix is
            given, the number of scores is the count of cluster pairs
            (i, j), i < j, with ``F[i, j] + F[j, i] > 0``. When omitted, all
            K*(K-1)/2 pairs are used by the 'sort' threshold.
    """

    def __init__(self, F: Optional[Union[int, np.ndarray]] = None):
        super().__init__()
        # Number of top imbalance terms kept by threshold='sort'. None means
        # "all K*(K-1)/2 pairs" and is resolved in forward() once K is known.
        self.sel: Optional[int] = None
        # np.integer is accepted too, so e.g. np.int64 is not mistaken for a matrix.
        if isinstance(F, (int, np.integer)):
            self.sel = int(F)
        elif F is not None:
            K = F.shape[0]
            self.sel = sum(1 for i in range(K - 1) for j in range(i + 1, K)
                           if (F[i, j] + F[j, i]) > 0)

    @staticmethod
    def _normalized_imbalance(w_kl, w_lk, vol_k, vol_l, second_max_vol,
                              epsilon, normalization):
        """Return |w_kl - w_lk| normalized according to ``normalization``."""
        diff = torch.abs(w_kl - w_lk)
        if normalization == 'vol_sum':
            return diff / (vol_k + vol_l + epsilon) * 2
        if normalization == 'vol_min':
            return (diff / (w_kl + w_lk) * torch.min(vol_k, vol_l)
                    / second_max_vol)
        if normalization == 'vol_max':
            return diff / (torch.max(vol_k, vol_l) + epsilon)
        # 'plain': no volume normalization, just the cut imbalance.
        return diff / (w_kl + w_lk)

    def forward(self, P: torch.FloatTensor,
                A: Union[torch.FloatTensor, torch.sparse_coo_tensor],
                K: int, normalization: str = 'vol_sum',
                threshold: str = 'sort') -> torch.FloatTensor:
        """Making a forward pass of the probabilistic imbalance loss function from the
        `DIGRAC: Digraph Clustering Based on Flow Imbalance
        <https://arxiv.org/pdf/2106.05194.pdf>`_ paper.

        Arg types:
            * **P** (PyTorch FloatTensor) - Prediction probability matrix made by the model;
              only its first K columns are used.
            * **A** (PyTorch FloatTensor, can be sparse) - Adjacency matrix A
            * **K** (int) - Number of clusters
            * **normalization** (str, optional) - normalization method:

                'vol_sum': Normalized by the sum of volumes, the default choice.

                'vol_max': Normalized by the maximum of volumes.

                'vol_min': Normalized by the minimum of volumes.

                'plain': No normalization, just CI.

            * **threshold** (str, optional) - thresholding method:

                'sort': Picking the top beta imbalance values, the default choice.

                'std': Picking only the terms 3 standard deviation away from null hypothesis.

                'naive': No thresholding, summing up all K*(K-1)/2 terms of imbalance values.

        Return types:
            * **loss** (torch.Tensor) - loss value of shape (1,), roughly in [0,1]. When no
              pair contributes (K < 2, nothing selected, or all imbalances zero) the loss is 1.
        """
        assert normalization in ['vol_sum', 'vol_min', 'vol_max', 'plain'], \
            'Please input the correct normalization method name!'
        assert threshold in ['sort', 'std', 'naive'], \
            'Please input the correct threshold method name!'

        device = A.device
        # Returned when no pair contributes; requires_grad keeps loss.backward()
        # callable in that degenerate case.
        ones = torch.ones(1, device=device, requires_grad=True)
        if K < 2:
            # No cluster pairs exist (and topk(vol, 2) below would fail).
            return ones

        # avoid a zero volume as a denominator
        epsilon = torch.tensor([1e-8], dtype=torch.float32, device=device)
        P_k = P[:, :K]
        # vol[k] = total of (A + A^T) @ P[:, k]: probabilistic volume of cluster k.
        vol = torch.sum(torch.matmul(A + torch.transpose(A, 0, 1), P_k), dim=0)
        second_max_vol = torch.topk(vol, 2).values[1] + epsilon
        # W[k, l] = P[:, k]^T A P[:, l], the probabilistic cut from cluster k to l.
        W = torch.matmul(P_k.t(), torch.matmul(A, P_k))

        imbalance = []      # terms kept by the threshold
        imbalance_std = []  # 'std' only: terms below the 3-std threshold
        for k in range(K - 1):
            for l in range(k + 1, K):
                w_kl, w_lk = W[k, l], W[l, k]
                diff = (w_kl - w_lk).item()
                if diff == 0:
                    continue
                curr = self._normalized_imbalance(
                    w_kl, w_lk, vol[k], vol[l], second_max_vol, epsilon,
                    normalization)
                # 'std': keep only terms 3 standard deviations from the null.
                if threshold != 'std' or diff ** 2 - 9 * (w_kl + w_lk).item() > 0:
                    imbalance.append(curr)
                else:
                    imbalance_std.append(curr)

        if threshold == 'sort':
            sel = getattr(self, 'sel', None)
            if sel is None:
                sel = K * (K - 1) // 2
            if sel <= 0 or not imbalance:
                # Nothing selected (avoids 0/0) or no pair is imbalanced.
                return ones
            scores = torch.stack([c.reshape(()) for c in imbalance])
            top = torch.topk(scores, min(int(sel), len(imbalance))).values
            # take negation to be minimized
            return ones - top.sum() / sel

        # 'naive' keeps everything in `imbalance`. For 'std', if no term passes
        # the threshold, thresholding is disregarded and the below-threshold
        # terms are used instead.
        terms = imbalance or imbalance_std
        if not terms:
            # nothing has positive imbalance
            return ones
        # torch.stack (not torch.FloatTensor) keeps the graph so gradients reach P.
        return ones - torch.stack([c.reshape(()) for c in terms]).mean()