torch_geometric_signed_directed.utils.general.triplet_loss

Functions

triplet_loss_node_classification(y, Z, n_sample, thre)

Triplet loss function for node classification.

Module Contents

triplet_loss_node_classification(y: numpy.array | torch.Tensor, Z: torch.FloatTensor, n_sample: int, thre: float)

Triplet loss function for node classification. It computes a triplet loss on the node embeddings Z, using the node labels y to form the triplets.

Arg types:
  • y (np.ndarray or torch.Tensor) - Node labels, one class label per node.

  • Z (torch.FloatTensor) - Node embedding matrix, with one row per node.

  • n_sample (int) - Number of samples to draw when constructing triplets.

  • thre (float) - Threshold value that determines whether a triplet difference is counted in the loss.

Return types:
  • loss (torch.FloatTensor) - Loss value.
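To make the role of the arguments more concrete, here is a minimal NumPy sketch of a triplet loss for node classification. It is not the library's implementation. The function name triplet_loss_sketch, the uniform random sampling of (anchor, positive, negative) triplets, the squared Euclidean distance, and the treatment of thre as a hinge margin (a triplet contributes only when the negative is not farther from the anchor than the positive by at least thre) are all assumptions made for illustration.

```python
import numpy as np


def triplet_loss_sketch(y, Z, n_sample, thre, seed=0):
    """Hypothetical hinge-style triplet loss (NOT the library's code).

    y: (N,) class labels; Z: (N, d) node embeddings.
    Samples n_sample triplets (anchor, positive, negative): the positive
    shares the anchor's label, the negative has a different label.
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    Z = np.asarray(Z, dtype=float)
    idx = np.arange(len(y))

    # A node can be an anchor only if its class has another member
    # and at least one node of a different class exists.
    valid = [i for i in idx
             if np.any((y == y[i]) & (idx != i)) and np.any(y != y[i])]
    if not valid:
        raise ValueError("need >= 2 nodes in some class and >= 2 classes")

    a = rng.choice(valid, size=n_sample)
    p = np.array([rng.choice(idx[(y == y[i]) & (idx != i)]) for i in a])
    n = np.array([rng.choice(idx[y != y[i]]) for i in a])

    # Squared Euclidean distances anchor-positive and anchor-negative.
    d_ap = np.sum((Z[a] - Z[p]) ** 2, axis=1)
    d_an = np.sum((Z[a] - Z[n]) ** 2, axis=1)

    # Hinge with margin thre: only triplets where d_an does not exceed
    # d_ap by at least thre contribute a positive term.
    losses = np.maximum(d_ap - d_an + thre, 0.0)
    return float(losses.mean())
```

With well-separated class clusters in Z the sketch returns 0.0, because every sampled triplet already satisfies the margin; if all embeddings collapse to a single point, every triplet contributes exactly thre, so the loss equals thre.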