torch_geometric_signed_directed.utils.directed.prob_imbalance_loss
Classes
Prob_Imbalance_Loss: An implementation of the probabilistic imbalance loss function from the DIGRAC: Digraph Clustering Based on Flow Imbalance paper.
Module Contents
- class Prob_Imbalance_Loss(F: int | numpy.ndarray | None = None)
Bases: torch.nn.Module
An implementation of the probabilistic imbalance loss function from the DIGRAC: Digraph Clustering Based on Flow Imbalance paper.
- Parameters:
F (int or NumPy array, optional)
- forward(P: torch.FloatTensor, A: torch.FloatTensor | torch.sparse_coo_tensor, K: int, normalization: str = 'vol_sum', threshold: str = 'sort') -> torch.FloatTensor
Makes a forward pass of the probabilistic imbalance loss function from the DIGRAC: Digraph Clustering Based on Flow Imbalance paper.
- Arg types:
P (PyTorch FloatTensor) - Prediction probability matrix made by the model
A (PyTorch FloatTensor, can be sparse) - Adjacency matrix A
K (int) - Number of clusters
normalization (str, optional) - normalization method:
‘vol_sum’: Normalized by the sum of volumes, the default choice.
‘vol_max’: Normalized by the maximum of volumes.
‘vol_min’: Normalized by the minimum of volumes.
‘plain’: No normalization, just CI.
threshold (str, optional) - thresholding method:
‘sort’: Picking the top beta imbalance values, the default choice.
‘std’: Picking only the terms 3 standard deviations away from the null hypothesis.
‘naive’: No thresholding, summing up all K*(K-1)/2 terms of imbalance values.
- Return types:
loss (torch.Tensor) - loss value, roughly in [0,1].
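To make the options above concrete, the following is a minimal, dependency-free sketch of how a loss of this shape could be computed. It is not the library's implementation: it assumes one reading of the DIGRAC paper in which the pairwise imbalance is the absolute net flow between two clusters, divided by either half the sum of their volumes (‘vol_sum’) or the total flow between them (‘plain’). The function name imbalance_loss_sketch and the beta argument are hypothetical, and only the ‘sort’ and ‘naive’ thresholds are covered.

```python
from itertools import combinations


def imbalance_loss_sketch(P, A, normalization="vol_sum", threshold="naive", beta=None):
    """Illustrative loss in [0, 1]; P is n x K (rows sum to 1), A is n x n."""
    n, K = len(P), len(P[0])
    # W[k][l]: probabilistic edge weight flowing from cluster k to cluster l.
    W = [[sum(P[i][k] * A[i][j] * P[j][l] for i in range(n) for j in range(n))
          for l in range(K)] for k in range(K)]
    # Node degree (out + in) and the probabilistic volume of each cluster.
    deg = [sum(A[i]) + sum(A[j][i] for j in range(n)) for i in range(n)]
    vol = [sum(deg[i] * P[i][k] for i in range(n)) for k in range(K)]

    scores = []
    for k, l in combinations(range(K), 2):  # all K*(K-1)/2 unordered pairs
        flow = abs(W[k][l] - W[l][k])
        if normalization == "vol_sum":
            denom = (vol[k] + vol[l]) / 2
        elif normalization == "plain":
            denom = W[k][l] + W[l][k]
        else:
            raise ValueError("this sketch only covers 'vol_sum' and 'plain'")
        scores.append(flow / denom if denom > 0 else 0.0)

    if threshold == "sort":
        # Keep only the beta largest imbalance values (assumed meaning of beta).
        top = len(scores) if beta is None else beta
        scores = sorted(scores, reverse=True)[:top]
    elif threshold != "naive":
        raise ValueError("this sketch only covers 'sort' and 'naive'")
    if not scores:
        return 1.0  # no cluster pairs to score
    return 1.0 - sum(scores) / len(scores)


# Hard assignment: nodes 0, 1 in cluster 0 and nodes 2, 3 in cluster 1.
P_hard = [[1, 0], [1, 0], [0, 1], [0, 1]]
# Every edge points from cluster 0 to cluster 1 (maximal imbalance).
A_flow = [[0, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 0], [0, 0, 0, 0]]
# Same edges made reciprocal (no net flow between the clusters).
A_sym = [[0, 0, 1, 1], [0, 0, 1, 1], [1, 1, 0, 0], [1, 1, 0, 0]]

loss_flow = imbalance_loss_sketch(P_hard, A_flow, normalization="vol_sum")
loss_sym = imbalance_loss_sketch(P_hard, A_sym, normalization="vol_sum")
print(loss_flow, loss_sym)
```

In this toy example the purely one-directional graph gives a loss of 0 and the reciprocal graph gives a loss of 1, which matches the "roughly in [0,1]" range stated for the return value.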