# Source code for torch_geometric_signed_directed.nn.signed.SNEA

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric_signed_directed.utils.signed import (create_spectral_features,
from .SNEAConv import SNEAConv

class SNEA(nn.Module):
    r"""The signed graph attentional layers operator from the
    "Learning Signed Network Embedding via Graph Attention" paper.

    The model stacks one :class:`SNEAConv` layer created with
    ``first_aggr=True``, ``layer_num - 1`` further :class:`SNEAConv` layers
    created with ``first_aggr=False`` and a final linear layer. A ``tanh``
    is applied after every layer.

    Args:
        node_num (int): The number of nodes.
        edge_index_s (LongTensor): The edgelist with sign, one row per edge
            holding the two node indices followed by the sign.
            (e.g., torch.LongTensor([[0, 1, -1], [0, 2, 1]]) )
            Rows whose sign is exactly 0 are used neither as positive nor
            as negative edges.
        in_dim (int, optional): Size of each input sample features.
            Defaults to 64.
        out_dim (int, optional): Size of each output embeddings.
            Defaults to 64.
        layer_num (int, optional): Number of layers. Defaults to 2.
        init_emb: (FloatTensor, optional): The initial embeddings.
            Defaults to :obj:`None`, which will use TSVD as initial
            embeddings (computed by ``create_spectral_features``).
            A supplied tensor is moved to the device of ``edge_index_s``.
        init_emb_grad (bool, optional): Optimize initial embeddings or not.
            (default: :obj:`True`)
        lamb (float, optional): Balances the contributions of the overall
            objective; it multiplies the structure loss in :meth:`loss`.
            (default: :obj:`4`)
    """

    def __init__(
        self,
        node_num: int,
        edge_index_s: torch.LongTensor,
        in_dim: int = 64,
        out_dim: int = 64,
        layer_num: int = 2,
        init_emb: torch.FloatTensor = None,
        init_emb_grad: bool = True,
        lamb: float = 4
    ):
        super().__init__()

        self.node_num = node_num
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.lamb = lamb
        self.device = edge_index_s.device

        # Column 2 carries the sign, columns 0-1 the node pair. Each subset
        # is transposed so that the stored index tensors have two rows.
        sign = edge_index_s[:, 2]
        self.pos_edge_index = edge_index_s[sign > 0][:, :2].t()
        self.neg_edge_index = edge_index_s[sign < 0][:, :2].t()

        if init_emb is None:
            init_emb = create_spectral_features(
                pos_edge_index=self.pos_edge_index,
                neg_edge_index=self.neg_edge_index,
                node_num=self.node_num,
                dim=self.in_dim
            )
        # Place the embeddings on the edge-list device in both cases, so a
        # caller-supplied tensor behaves like the default one. `.to` returns
        # the tensor itself when it already lives on that device.
        init_emb = init_emb.to(self.device)
        self.x = nn.Parameter(init_emb, requires_grad=init_emb_grad)

        # Every conv layer is created with out_dim // 2 output channels while
        # the final linear layer consumes out_dim features.
        # NOTE(review): presumably SNEAConv concatenates two halves of that
        # size (which would also require an even out_dim) -- confirm in
        # SNEAConv.
        half_dim = out_dim // 2
        self.conv1 = SNEAConv(in_dim, half_dim, first_aggr=True)
        self.convs = torch.nn.ModuleList([
            SNEAConv(half_dim, half_dim, first_aggr=False)
            for _ in range(layer_num - 1)
        ])

        self.weight = torch.nn.Linear(self.out_dim, self.out_dim)
        self.lsp_loss = Link_Sign_Entropy_Loss(out_dim)
        self.structure_loss = Sign_Structure_Loss()

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the first conv layer, the stacked conv layers and
        the final linear layer by calling their own ``reset_parameters``.

        The embedding parameter ``self.x`` is left untouched.
        """
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.weight.reset_parameters()

    def loss(self) -> torch.FloatTensor:
        """Run a forward pass and return the training objective.

        Returns:
            The link-sign loss plus ``lamb`` times the structure loss, both
            evaluated on the embeddings from :meth:`forward` and the
            positive / negative edge indices stored at construction.
        """
        z = self.forward()
        nll_loss = self.lsp_loss(z, self.pos_edge_index, self.neg_edge_index)
        structure_loss = self.structure_loss(
            z, self.pos_edge_index, self.neg_edge_index)
        return nll_loss + self.lamb * structure_loss

    def forward(self) -> Tensor:
        """Compute the node embeddings from the stored features and edges.

        Takes no arguments: the input features ``self.x`` and both edge
        index tensors are fixed at construction time.

        Returns:
            The output of the final linear layer after ``tanh``.
        """
        z = torch.tanh(self.conv1(
            self.x, self.pos_edge_index, self.neg_edge_index))
        for conv in self.convs:
            z = torch.tanh(conv(z, self.pos_edge_index, self.neg_edge_index))
        z = torch.tanh(self.weight(z))
        return z