torch_geometric_signed_directed.utils.directed.prob_imbalance_loss

Classes

Prob_Imbalance_Loss

An implementation of the probabilistic imbalance loss function from the “DIGRAC: Digraph Clustering Based on Flow Imbalance” paper.

Module Contents

class Prob_Imbalance_Loss(F: int | numpy.ndarray | None = None)

Bases: torch.nn.Module

An implementation of the probabilistic imbalance loss function from the “DIGRAC: Digraph Clustering Based on Flow Imbalance” paper.

Parameters:

F (int or NumPy array, optional)

forward(P: torch.FloatTensor, A: torch.FloatTensor | torch.sparse_coo_tensor, K: int, normalization: str = 'vol_sum', threshold: str = 'sort') → torch.FloatTensor

Make a forward pass of the probabilistic imbalance loss function from the “DIGRAC: Digraph Clustering Based on Flow Imbalance” paper.
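The imbalance between two clusters is computed from the edge weight flowing between them. With soft cluster assignments, a natural probabilistic version of the weight from cluster k to cluster l is the bilinear form w_kl = P[:, k]ᵀ A P[:, l]. The minimal sketch below computes these pairwise flows and per-cluster volumes in plain Python. It assumes a dense n×K probability matrix P and a dense n×n adjacency matrix A given as nested lists, and it assumes a cluster's volume is its membership-weighted total of in- plus out-degrees. The helpers `pairwise_flow` and `cluster_volumes` are hypothetical illustrations, not part of the library.

```python
from typing import List

Matrix = List[List[float]]


def pairwise_flow(P: Matrix, A: Matrix, k: int, l: int) -> float:
    """Expected edge weight from cluster k to cluster l: P[:, k]^T A P[:, l]."""
    n = len(A)
    return sum(
        P[i][k] * A[i][j] * P[j][l]
        for i in range(n)
        for j in range(n)
    )


def cluster_volumes(P: Matrix, A: Matrix, K: int) -> List[float]:
    """Soft volume of each cluster (an assumption): total of (A + A^T) @ P[:, k]."""
    n = len(A)
    vols = []
    for k in range(K):
        vol = 0.0
        for i in range(n):
            for j in range(n):
                # in- plus out-degree of node j, weighted by its membership in cluster k
                vol += (A[i][j] + A[j][i]) * P[j][k]
        vols.append(vol)
    return vols


A = [[0, 0, 1], [0, 0, 1], [0, 0, 0]]  # edges 0->2 and 1->2
P = [[1, 0], [1, 0], [0, 1]]           # hard assignment: {0, 1} in cluster 0, {2} in cluster 1
print(pairwise_flow(P, A, 0, 1))  # 2
print(pairwise_flow(P, A, 1, 0))  # 0
print(cluster_volumes(P, A, 2))   # [2.0, 2.0]
```

With hard (one-hot) assignments the flow reduces to the ordinary cut weight between the two node sets; soft assignments make it differentiable with respect to P.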

Arg types:
  • P (PyTorch FloatTensor) - Prediction probability matrix made by the model

  • A (PyTorch FloatTensor, can be sparse) - Adjacency matrix A

  • K (int) - Number of clusters

  • normalization (str, optional) - normalization method:

    ‘vol_sum’: Normalized by the sum of volumes, the default choice.

    ‘vol_max’: Normalized by the maximum of volumes.

    ‘vol_min’: Normalized by the minimum of volumes.

    ‘plain’: No normalization; use the plain cut imbalance (CI).

  • threshold (str, optional) - thresholding method:

    ‘sort’: Pick the top beta imbalance values, the default choice.

    ‘std’: Pick only the terms at least 3 standard deviations away from the null hypothesis.

    ‘naive’: No thresholding; sum up all K*(K-1)/2 imbalance terms.
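The normalization and thresholding options listed above can be illustrated on precomputed pairwise flows. The sketch below is a simplified stand-in under stated assumptions, not the library's code: it takes ‘plain’ to be the cut imbalance |w_kl - w_lk| / (w_kl + w_lk), assumes the ‘vol_*’ variants divide |w_kl - w_lk| by the sum, maximum, or minimum of the two cluster volumes, and implements only the ‘sort’ and ‘naive’ thresholds (the ‘std’ test is omitted). The function names and the `beta` argument (the number of terms kept by ‘sort’) are hypothetical.

```python
from itertools import combinations
from typing import List


def pair_imbalance(w_kl: float, w_lk: float, vol_k: float, vol_l: float,
                   normalization: str = "vol_sum", eps: float = 1e-8) -> float:
    """Imbalance of one cluster pair under an ASSUMED form of each normalization."""
    diff = abs(w_kl - w_lk)
    if normalization == "vol_sum":
        return diff / (vol_k + vol_l + eps)
    if normalization == "vol_max":
        return diff / (max(vol_k, vol_l) + eps)
    if normalization == "vol_min":
        return diff / (min(vol_k, vol_l) + eps)
    if normalization == "plain":
        # cut imbalance ratio, no volume normalization
        return diff / (w_kl + w_lk + eps)
    raise ValueError(f"unknown normalization: {normalization}")


def select_imbalances(W: List[List[float]], vols: List[float],
                      normalization: str = "vol_sum", threshold: str = "sort",
                      beta: int = 1) -> List[float]:
    """Score every pair k < l, then keep the terms chosen by the threshold rule."""
    K = len(vols)
    scores = [
        pair_imbalance(W[k][l], W[l][k], vols[k], vols[l], normalization)
        for k, l in combinations(range(K), 2)  # the K*(K-1)/2 unordered pairs
    ]
    if threshold == "naive":
        return scores                               # keep every pair
    if threshold == "sort":
        return sorted(scores, reverse=True)[:beta]  # top-beta imbalance values
    raise ValueError(f"threshold not covered by this sketch: {threshold}")


W = [                 # W[k][l] = flow from cluster k to cluster l
    [0.0, 4.0, 1.0],
    [0.0, 0.0, 3.0],
    [1.0, 1.0, 0.0],
]
vols = [10.0, 10.0, 10.0]
print([round(s, 3) for s in select_imbalances(W, vols, "plain", "naive")])          # [1.0, 0.0, 0.5]
print([round(s, 3) for s in select_imbalances(W, vols, "vol_sum", "sort", beta=2)])  # [0.2, 0.1]
```

The small `eps` guards against division by zero when two clusters exchange no flow or have empty volume.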

Return types:
  • loss (torch.Tensor) - loss value, roughly in [0,1].
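Since the goal is to maximize flow imbalance while an optimizer minimizes, one way to obtain a loss roughly in [0,1] is to subtract the average selected imbalance from one. The sketch below assumes exactly that aggregation; it is an illustration of the idea, not a statement of the library's internals, and `imbalance_loss` is a hypothetical name.

```python
from typing import List


def imbalance_loss(selected: List[float]) -> float:
    """ASSUMED aggregation: one minus the mean of the selected imbalance terms.

    Maximizing imbalance then corresponds to minimizing the loss, and the
    result lies in [0, 1] whenever every selected term lies in [0, 1].
    """
    if not selected:
        return 1.0  # no imbalance found: worst possible value under this sketch
    return 1.0 - sum(selected) / len(selected)


print(imbalance_loss([1.0, 0.5]))  # 0.25
print(imbalance_loss([]))          # 1.0
```

Averaging rather than summing keeps the scale independent of how many pairs survive thresholding, which is one plausible reason the value stays roughly within [0,1].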