Source code for torch_geometric_signed_directed.utils.general.triplet_loss

from typing import Union
import random

import torch
import numpy as np

def triplet_loss_node_classification(y: Union[np.ndarray, torch.Tensor], Z: torch.FloatTensor,
                                     n_sample: int, thre: float) -> torch.Tensor:
    r""" Triplet loss function for node classification.

    For every class present in :obj:`y`, anchors and positives are drawn
    (with replacement) from the nodes of that class and negatives from all
    remaining nodes. Each triplet contributes the hinge term
    :math:`\max(0, \langle z_a, z_n \rangle - \langle z_a, z_p \rangle + thre)`
    and the result is the mean over all sampled triplets.

    Sampling uses Python's :mod:`random` module, so call :func:`random.seed`
    beforehand for reproducible values.

    Arg types:
        * **y** (np.array or torch.Tensor) - Node labels, one per row of :obj:`Z`.
          Labels need not start at zero nor be contiguous; only the values
          actually present are treated as classes.
        * **Z** (torch.FloatTensor) - Embedding matrix for nodes.
        * **n_sample** (int) - Total number of triplets to sample; it is split
          evenly across classes, with a minimum of 32 triplets per class.
        * **thre** (float) - Threshold value for the triplet differences to be counted.

    Return types:
        * **loss** (torch.FloatTensor) - Loss value.

    Raises:
        ValueError: If :obj:`y` contains fewer than two distinct labels, since
            no (anchor, positive, negative) triplet can be formed.
    """
    if isinstance(y, torch.Tensor):
        y = y.detach().cpu().numpy()
    y = np.asarray(y)

    # Use the labels that actually occur instead of assuming 0..K-1: the
    # former `range(y.max() - y.min() + 1)` loop hit empty classes (and an
    # opaque IndexError) for 1-based or non-contiguous labels.
    labels = np.unique(y)
    nclass = len(labels)
    if nclass < 2:
        raise ValueError(
            'triplet_loss_node_classification needs at least two distinct '
            'labels in y, got {}.'.format(nclass))

    all_nodes_array = np.arange(Z.shape[0])
    # Every class gets an equal share of the budget, but never fewer than 32.
    n_sample_class = max(int(n_sample / nclass), 32)

    loss = 0
    for label in labels:
        in_class = y == label
        same_nodes = all_nodes_array[in_class]
        other_nodes = all_nodes_array[~in_class]

        # Draw order (anchor, positive, negative) is kept fixed so that a
        # given random seed yields the same triplets as before.
        anchors = Z[random.choices(same_nodes, k=n_sample_class)]
        positives = Z[random.choices(same_nodes, k=n_sample_class)]
        negatives = Z[random.choices(other_nodes, k=n_sample_class)]

        # inner product: same class inner product should > dif class inner product
        inner_products = torch.sum(anchors * (negatives - positives), dim=1)
        # Hinge: only triplets violating the margin contribute.
        loss = loss + torch.relu(inner_products + thre).sum()

    return loss / (n_sample_class * nclass)