from typing import Optional, Union
import torch
import numpy as np
class Prob_Imbalance_Loss(torch.nn.Module):
    r"""An implementation of the probabilistic imbalance loss function from the
    `DIGRAC: Digraph Clustering Based on Flow Imbalance <https://proceedings.mlr.press/v198/he22b.html>`_ paper.

    Args:
        F (int or NumPy array, optional) - Number of pairwise imbalance scores to consider, or the
            meta-graph adjacency matrix (in which case the number of cluster pairs connected in
            either direction is used). When omitted, the 'sort' thresholding scheme considers all
            K*(K-1)/2 cluster pairs.
    """

    def __init__(self, F: Optional[Union[int, np.ndarray]] = None):
        super(Prob_Imbalance_Loss, self).__init__()
        # Always define the attribute so that forward() never hits an
        # AttributeError when the module is built without F.
        self.sel = None
        if isinstance(F, (int, np.integer)):
            self.sel = int(F)
        elif F is not None:
            K = F.shape[0]
            # Count unordered cluster pairs linked in at least one direction.
            self.sel = sum(
                1
                for i in range(K - 1)
                for j in range(i + 1, K)
                if (F[i, j] + F[j, i]) > 0
            )

    @staticmethod
    def _normalized_imbalance(w_kl, w_lk, vol_k, vol_l, second_max_vol,
                              epsilon, normalization):
        """Return the normalized imbalance score of one cluster pair."""
        flow_gap = torch.abs(w_kl - w_lk)
        if normalization == 'vol_sum':
            return flow_gap / (vol_k + vol_l + epsilon) * 2
        if normalization == 'vol_min':
            return flow_gap / (w_kl + w_lk) * torch.min(vol_k, vol_l) / second_max_vol
        if normalization == 'vol_max':
            return flow_gap / (torch.max(vol_k, vol_l) + epsilon)
        if normalization == 'plain':
            return flow_gap / (w_kl + w_lk)
        # Only reachable when the assertions in forward() are stripped (-O).
        raise ValueError('Please input the correct normalization method name!')

    @staticmethod
    def _mean_score(scores):
        """Average a list of score tensors while staying on the autograd graph."""
        # Rebuilding a tensor from the scalar values (torch.FloatTensor(list))
        # would detach the scores; concatenating keeps their gradients.
        return torch.cat([score.reshape(-1) for score in scores]).mean()

    def forward(self, P: torch.FloatTensor, A: Union[torch.FloatTensor, torch.sparse_coo_tensor],
                K: int, normalization: str = 'vol_sum', threshold: str = 'sort') -> torch.FloatTensor:
        """Making a forward pass of the probabilistic imbalance loss function from the
        `DIGRAC: Digraph Clustering Based on Flow Imbalance" <https://arxiv.org/pdf/2106.05194.pdf>`_ paper.

        Arg types:
            * **P** (PyTorch FloatTensor) - Prediction probability matrix made by the model
            * **A** (PyTorch FloatTensor, can be sparse) - Adjacency matrix A
            * **K** (int) - Number of clusters
            * **normalization** (str, optional) - normalization method:
                'vol_sum': Normalized by the sum of volumes, the default choice.
                'vol_max': Normalized by the maximum of volumes.
                'vol_min': Normalized by the minimum of volumes.
                'plain': No normalization, just CI.
            * **threshold**: (str, optional) thresholding method:
                'sort': Picking the top beta imbalance values, the default choice.
                'std': Picking only the terms 3 standard deviation away from null hypothesis.
                'naive': No thresholding, summing up all K*(K-1)/2 terms of imbalance values.

        Return types:
            * **loss** (torch.Tensor) - loss value of shape (1,), roughly in [0,1].
        """
        assert normalization in ['vol_sum', 'vol_min', 'vol_max',
                                 'plain'], 'Please input the correct normalization method name!'
        assert threshold in [
            'sort', 'std', 'naive'], 'Please input the correct threshold method name!'
        device = A.device
        # avoid a zero volume in the denominator
        epsilon = torch.FloatTensor([1e-8]).to(device)

        # probabilistic volume of every cluster; the symmetrized adjacency
        # matrix is loop-invariant, so build it once
        A_sym = A + torch.transpose(A, 0, 1)
        vol = torch.zeros(K).to(device)
        for k in range(K):
            vol[k] = torch.sum(torch.matmul(A_sym, P[:, k:k+1]))
        # With a single cluster there is no pair to score and topk(vol, 2)
        # would fail, so fall back to epsilon alone.
        if K >= 2:
            second_max_vol = torch.topk(vol, 2).values[1] + epsilon
        else:
            second_max_vol = epsilon

        imbalance = []      # scores that pass the thresholding test
        imbalance_std = []  # below-threshold scores of the 'std' scheme
        for k in range(K-1):
            for l in range(k+1, K):
                w_kl = torch.matmul(P[:, k], torch.matmul(A, P[:, l]))
                w_lk = torch.matmul(P[:, l], torch.matmul(A, P[:, k]))
                gap = (w_kl - w_lk).item()
                if gap == 0:
                    continue
                curr = self._normalized_imbalance(
                    w_kl, w_lk, vol[k], vol[l], second_max_vol, epsilon, normalization)
                # 'std': keep only pairs whose squared gap exceeds 9 * total flow
                if threshold != 'std' or np.power(gap, 2) - 9*(w_kl + w_lk).item() > 0:
                    imbalance.append(curr)
                else:
                    imbalance_std.append(curr)

        # Constant the averaged imbalance is subtracted from (negation so that
        # the loss is minimized); requires_grad keeps backward() callable even
        # when no pair contributes.
        one = torch.ones(1, requires_grad=True).to(device)

        if threshold == 'sort':
            sel = getattr(self, 'sel', None)
            if sel is None:
                # no F given at construction: consider every cluster pair
                sel = K * (K - 1) // 2
            if sel <= 0:
                # nothing to select; avoid dividing by zero
                return one
            imbalance_values = [curr.item() for curr in imbalance]
            # descending order
            ind_sorted = np.argsort(-np.array(imbalance_values))
            result = torch.zeros(1).to(device)
            for ind in ind_sorted[:int(sel)]:
                result = result + imbalance[ind]
            return one - result / sel
        if len(imbalance) > 0:
            return one - self._mean_score(imbalance)
        if threshold == 'std' and len(imbalance_std) > 0:
            # no pair passed the test, so disregard thresholding
            return one - self._mean_score(imbalance_std)
        # nothing has positive imbalance
        return one