from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch_geometric_signed_directed.utils.signed import (
    create_spectral_features,
    Link_Sign_Entropy_Loss,
    Sign_Structure_Loss,
)
from .SNEAConv import SNEAConv


class SNEA(nn.Module):
    r"""The signed network embedding via graph attention (SNEA) model from
    the `"Learning Signed Network Embedding via Graph Attention"
    <https://ojs.aaai.org/index.php/AAAI/article/view/5911>`_ paper.

    Args:
        node_num (int): The number of nodes.
        edge_index_s (LongTensor): The signed edge list, one
            ``[source, target, sign]`` row per edge (e.g.,
            ``torch.LongTensor([[0, 1, -1], [0, 2, 1]])``).
        in_dim (int, optional): Size of each input node feature vector.
            (default: :obj:`64`)
        out_dim (int, optional): Size of each output embedding.
            (default: :obj:`64`)
        layer_num (int, optional): Number of layers. (default: :obj:`2`)
        init_emb (FloatTensor, optional): The initial embeddings. If set to
            :obj:`None`, truncated SVD (TSVD) spectral features are used as
            initial embeddings. (default: :obj:`None`)
        init_emb_grad (bool, optional): Whether to optimize the initial
            embeddings. (default: :obj:`True`)
        lamb (float, optional): Balances the contributions of the link-sign
            loss and the structural loss in the overall objective.
            (default: :obj:`4`)
    """
    def __init__(
        self,
        node_num: int,
        edge_index_s: torch.LongTensor,
        in_dim: int = 64,
        out_dim: int = 64,
        layer_num: int = 2,
        init_emb: Optional[torch.FloatTensor] = None,
        init_emb_grad: bool = True,
        lamb: float = 4.0,
    ):
        super().__init__()
        self.node_num = node_num
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.lamb = lamb
        self.device = edge_index_s.device

        # Split the signed edge list into positive and negative edge
        # indices, each of shape [2, num_edges].
        self.pos_edge_index = edge_index_s[edge_index_s[:, 2] > 0][:, :2].t()
        self.neg_edge_index = edge_index_s[edge_index_s[:, 2] < 0][:, :2].t()

        if init_emb is None:
            # Fall back to truncated-SVD spectral features of the signed
            # adjacency matrix as the initial node embeddings.
            init_emb = create_spectral_features(
                pos_edge_index=self.pos_edge_index,
                neg_edge_index=self.neg_edge_index,
                node_num=self.node_num,
                dim=self.in_dim,
            ).to(self.device)
        self.x = nn.Parameter(init_emb, requires_grad=init_emb_grad)
        # The first attention layer consumes the initial features; each
        # layer maintains two half-size representations (out_dim // 2 each),
        # which together form the out_dim-sized output.
        self.conv1 = SNEAConv(in_dim, out_dim // 2,
                              first_aggr=True)
        self.convs = torch.nn.ModuleList()
        for _ in range(layer_num - 1):
            self.convs.append(
                SNEAConv(out_dim // 2, out_dim // 2,
                         first_aggr=False))
        self.weight = torch.nn.Linear(self.out_dim, self.out_dim)

        # Loss terms: link-sign prediction and signed-structure preservation.
        self.lsp_loss = Link_Sign_Entropy_Loss(out_dim)
        self.structure_loss = Sign_Structure_Loss()

        self.reset_parameters()

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.weight.reset_parameters()

    def loss(self) -> torch.FloatTensor:
        z = self.forward()
        # Link-sign entropy loss plus the lamb-weighted structural loss.
        nll_loss = self.lsp_loss(z, self.pos_edge_index, self.neg_edge_index)
        structure_loss = self.structure_loss(
            z, self.pos_edge_index, self.neg_edge_index)
        return nll_loss + self.lamb * structure_loss

    def forward(self) -> Tensor:
        # Apply the first attention layer to the (learnable) input features.
        z = torch.tanh(self.conv1(
            self.x, self.pos_edge_index, self.neg_edge_index))
        for conv in self.convs:
            z = torch.tanh(conv(z, self.pos_edge_index, self.neg_edge_index))
        # Final linear transform on the full concatenated representation.
        z = torch.tanh(self.weight(z))
        return z
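

if __name__ == "__main__":
    # A minimal usage sketch, not part of the library source. It assumes a
    # toy 4-node signed graph and passes random initial embeddings so that
    # the TSVD spectral initialization is skipped on such a small graph.
    # Because of the relative import above, run this as a module from the
    # package root (python -m ...), not as a standalone script.
    edge_index_s = torch.LongTensor([
        [0, 1, 1],    # positive edge 0 -> 1
        [0, 2, -1],   # negative edge 0 -> 2
        [1, 3, 1],
        [2, 3, -1],
    ])
    model = SNEA(node_num=4, edge_index_s=edge_index_s,
                 in_dim=8, out_dim=8, layer_num=2,
                 init_emb=torch.randn(4, 8))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(10):
        optimizer.zero_grad()
        loss = model.loss()
        loss.backward()
        optimizer.step()
    z = model()  # final node embeddings of shape [node_num, out_dim]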