torch_geometric_signed_directed.utils.general.triplet_loss
Functions
- triplet_loss_node_classification(y, Z, n_sample, thre): Triplet loss function for node classification.
Module Contents
- triplet_loss_node_classification(y: numpy.array | torch.Tensor, Z: torch.FloatTensor, n_sample: int, thre: float)
Triplet loss function for node classification.
- Arg types:
y (np.array or torch.Tensor) - Node labels.
Z (torch.FloatTensor) - Embedding matrix for nodes.
n_sample (int) - Number of samples.
thre (float) - Threshold value for the triplet differences to be counted.
- Return types:
loss (torch.FloatTensor) - Loss value.
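This page does not specify how triplets are sampled or exactly how `thre` enters the loss. For illustration only, here is a NumPy sketch of a common triplet margin formulation. It is not the library's implementation. The function name `triplet_loss_sketch`, the `seed` parameter, per-class uniform sampling, L2 normalisation, squared Euclidean distance, and treating `thre` as the margin in max(0, d(a, p) - d(a, n) + thre) are all assumptions.

```python
import numpy as np


def triplet_loss_sketch(y, Z, n_sample, thre, seed=0):
    """Hypothetical sketch of a triplet margin loss over labelled node embeddings.

    Assumptions (not from torch_geometric_signed_directed): triplets are drawn
    uniformly per class, embeddings are L2-normalised, distances are squared
    Euclidean, and `thre` is the margin in max(0, d(a, p) - d(a, n) + thre).
    """
    rng = np.random.default_rng(seed)
    y = np.asarray(y)
    Z = np.asarray(Z, dtype=float)
    # L2-normalise each embedding row (guarding against zero-norm rows).
    norms = np.linalg.norm(Z, axis=1, keepdims=True)
    Z = Z / np.maximum(norms, 1e-12)
    classes = np.unique(y)
    per_class = max(n_sample // len(classes), 1)
    anchors, positives, negatives = [], [], []
    for c in classes:
        same = np.flatnonzero(y == c)
        other = np.flatnonzero(y != c)
        if len(other) == 0:
            continue  # no negatives exist for this class
        # Sample anchor, positive (same class) and negative (other class) indices.
        anchors.append(rng.choice(same, per_class))
        positives.append(rng.choice(same, per_class))
        negatives.append(rng.choice(other, per_class))
    if not anchors:
        return 0.0  # fewer than two classes: no valid triplets
    a = np.concatenate(anchors)
    p = np.concatenate(positives)
    n = np.concatenate(negatives)
    # Squared Euclidean distances anchor-positive and anchor-negative.
    d_ap = np.sum((Z[a] - Z[p]) ** 2, axis=1)
    d_an = np.sum((Z[a] - Z[n]) ** 2, axis=1)
    # Hinge: only triplets with d_ap - d_an + thre > 0 contribute.
    return float(np.mean(np.maximum(d_ap - d_an + thre, 0.0)))
```

Consider embeddings that already separate the two classes, for example `y = [0, 0, 1, 1]` with `Z = [[1, 0], [1, 0], [0, 1], [0, 1]]`. Here every anchor-negative squared distance is 2. The loss is therefore 0 for any margin below 2. If all embeddings collapse to the same point, the loss equals the margin.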