# Source code for torch_geometric_signed_directed.utils.signed.link_sign_loss

import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import (negative_sampling,
                                   structured_negative_sampling)

[docs]class Sign_Triangle_Loss(nn.Module): r"""An implementation of the Signed Triangle Loss used in `"SDGNN: Learning Node Representation for Signed Directed Networks" <>`_ paper. Args: emb_dim (int): The embedding size. """ def __init__(self, emb_dim: int, edge_weight: sp.csc_matrix ) -> None: super().__init__() self.lin = nn.Linear(emb_dim * 2, 1) self.edge_weight = edge_weight
[docs] def forward( self, z: torch.Tensor, pos_edge_index: torch.LongTensor, neg_edge_index: torch.LongTensor ) -> torch.Tensor: device = z.device z_11 = z[pos_edge_index[0], :] z_12 = z[pos_edge_index[1], :] ind1 = pos_edge_index[0].cpu().numpy().tolist() ind2 = pos_edge_index[1].cpu().numpy().tolist() edge_w1 = torch.from_numpy(self.edge_weight[ind1, ind2]).reshape(-1, 1).to(device) z_21 = z[neg_edge_index[0], :] z_22 = z[neg_edge_index[1], :] ind1 = neg_edge_index[0].cpu().numpy().tolist() ind2 = neg_edge_index[1].cpu().numpy().tolist() edge_w2 = torch.from_numpy(self.edge_weight[ind1, ind2]).reshape(-1, 1).to(device) rs1 = self.lin([z_11, z_12], dim=1)) rs2 = self.lin([z_21, z_22], dim=1)) pos_loss = F.binary_cross_entropy_with_logits(rs1, torch.ones_like(rs1), weight=edge_w1, reduction='sum') neg_loss = F.binary_cross_entropy_with_logits(rs2, torch.zeros_like(rs2), weight=edge_w2, reduction='sum') return pos_loss + neg_loss
class Sign_Direction_Loss(nn.Module):
    r"""An implementation of the Signed Direction Loss used in the
    "SDGNN: Learning Node Representation for Signed Directed Networks" paper.

    Two scoring heads (``Linear -> Sigmoid``) map an embedding to a score in
    ``(0, 1)``. For an edge ``(u, v)`` let
    ``d = score_function1(z[u]) - score_function2(z[v])``. Positive edges are
    penalised by ``relu(d + 0.5) ** 2`` and negative edges by
    ``relu(0.5 - d) ** 2``. Both penalties are summed over edges.

    Args:
        emb_dim (int): The embedding size.
    """

    def __init__(self, emb_dim: int) -> None:
        super().__init__()
        # Keep creation order and the Sequential layout: they determine the
        # parameter-initialisation order and the state_dict keys.
        self.score_function1 = nn.Sequential(nn.Linear(emb_dim, 1), nn.Sigmoid())
        self.score_function2 = nn.Sequential(nn.Linear(emb_dim, 1), nn.Sigmoid())

    def _score_gap(
        self, z: torch.Tensor, edge_index: torch.LongTensor
    ) -> torch.Tensor:
        """Return ``score_function1(source) - score_function2(target)`` per edge."""
        head = self.score_function1(z[edge_index[0], :])
        tail = self.score_function2(z[edge_index[1], :])
        return head - tail

    def forward(
        self,
        z: torch.Tensor,
        pos_edge_index: torch.LongTensor,
        neg_edge_index: torch.LongTensor,
    ) -> torch.Tensor:
        """Return the summed squared-hinge direction loss for both edge signs."""
        pos_gap = self._score_gap(z, pos_edge_index)
        neg_gap = self._score_gap(z, neg_edge_index)
        # Positive edges should satisfy d <= -0.5; only the excess is penalised.
        pos_loss = F.relu(pos_gap + 0.5).pow(2).sum()
        # Negative edges should satisfy d >= 0.5; only the shortfall is penalised.
        neg_loss = F.relu(0.5 - neg_gap).pow(2).sum()
        return pos_loss + neg_loss
class Sign_Product_Entropy_Loss(nn.Module):
    r"""An implementation of the Signed Entropy Loss used in the
    "SDGNN: Learning Node Representation for Signed Directed Networks" paper.

    Each edge ``(u, v)`` is scored by the inner product ``<z[u], z[v]>``,
    which is treated as a logit. Positive edges target label 1 and negative
    edges target label 0. The binary cross-entropy terms are summed.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        z: torch.Tensor,
        pos_edge_index: torch.LongTensor,
        neg_edge_index: torch.LongTensor,
    ) -> torch.Tensor:
        """Return the summed BCE of endpoint inner products for both edge signs."""

        def inner_products(edge_index: torch.LongTensor) -> torch.Tensor:
            # Row-wise dot product of source and target embeddings.
            src = z[edge_index[0], :]
            dst = z[edge_index[1], :]
            return torch.einsum("ij, ij->i", [src, dst])

        pos_logits = inner_products(pos_edge_index)
        neg_logits = inner_products(neg_edge_index)

        pos_term = F.binary_cross_entropy_with_logits(
            pos_logits, torch.ones_like(pos_logits), reduction='sum'
        )
        neg_term = F.binary_cross_entropy_with_logits(
            neg_logits, torch.zeros_like(neg_logits), reduction='sum'
        )
        return pos_term + neg_term
class Sign_Structure_Loss(nn.Module):
    r"""Triplet-style structural loss over positive and negative signed edges.

    For each edge ``(i, j)`` a node ``k`` is drawn with
    ``structured_negative_sampling`` (per the PyG docs, ``(i, k)`` is a
    sampled non-edge). Positive edges should keep ``j`` closer to ``i`` than
    ``k``; negative edges should keep ``j`` farther from ``i`` than ``k``.
    Violations are hinge-penalised on squared Euclidean distances and averaged.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        z: torch.Tensor,
        pos_edge_index: torch.LongTensor,
        neg_edge_index: torch.LongTensor,
    ) -> torch.Tensor:
        """Return the sum of the positive- and negative-edge triplet losses."""
        # The positive term is evaluated first, so the sampler draws its
        # random numbers in a fixed order.
        pos_term = self.pos_embedding_loss(z, pos_edge_index)
        neg_term = self.neg_embedding_loss(z, neg_edge_index)
        return pos_term + neg_term

    @staticmethod
    def _hinge_triplet(
        z: torch.Tensor,
        anchor: torch.Tensor,
        closer: torch.Tensor,
        farther: torch.Tensor,
    ) -> torch.Tensor:
        """Mean of ``max(0, |z[anchor]-z[closer]|^2 - |z[anchor]-z[farther]|^2)``."""
        near = (z[anchor] - z[closer]).pow(2).sum(dim=1)
        far = (z[anchor] - z[farther]).pow(2).sum(dim=1)
        return torch.clamp(near - far, min=0).mean()

    def pos_embedding_loss(
        self, z: torch.Tensor, pos_edge_index: torch.LongTensor
    ) -> torch.Tensor:
        """Computes the triplet loss between positive node pairs and sampled
        non-node pairs.

        Args:
            z (Tensor): The node embeddings.
            pos_edge_index (LongTensor): The positive edge indices.
        """
        i, j, k = structured_negative_sampling(pos_edge_index, z.size(0))
        # Linked node j should be nearer to i than the sampled node k.
        return self._hinge_triplet(z, i, j, k)

    def neg_embedding_loss(
        self, z: torch.Tensor, neg_edge_index: torch.LongTensor
    ) -> torch.Tensor:
        """Computes the triplet loss between negative node pairs and sampled
        non-node pairs.

        Args:
            z (Tensor): The node embeddings.
            neg_edge_index (LongTensor): The negative edge indices.
        """
        i, j, k = structured_negative_sampling(neg_edge_index, z.size(0))
        # Negatively linked node j should be farther from i than sampled k.
        return self._hinge_triplet(z, i, k, j)