# Source code for torch_geometric_signed_directed.nn.signed.SNEAConv

from typing import Optional, Union

import torch
from torch import LongTensor, Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import OptTensor, PairTensor
from torch_geometric.utils import (add_self_loops, remove_self_loops,
                                   softmax)

class SNEAConv(MessagePassing):
    r"""The signed graph attentional layer operator from the
    "Learning Signed Network Embedding via Graph Attention" paper.

    .. math::
        \mathbf{h}_{i}^{\mathcal{B}(l)} = \tanh \left(
            \sum_{j \in \hat{\mathcal{N}}_{i}^{+}, k \in \mathcal{N}_{i}^{-}}
            \alpha_{i j}^{\mathcal{B}(l)} \mathbf{h}_{j}^{\mathcal{B}(l-1)}
            \mathbf{W}^{\mathcal{B}(l)}
            + \alpha_{i k}^{\mathcal{B}(l)} \mathbf{h}_{k}^{\mathcal{U}(l-1)}
            \mathbf{W}^{\mathcal{B}(l)} \right)

        \mathbf{h}_{i}^{\mathcal{U}(l)} = \tanh \left(
            \sum_{j \in \hat{\mathcal{N}}_{i}^{+}, k \in \mathcal{N}_{i}^{-}}
            \alpha_{i j}^{\mathcal{U}(l)} \mathbf{h}_{j}^{\mathcal{U}(l-1)}
            \mathbf{W}^{\mathcal{U}(l)}
            + \alpha_{i k}^{\mathcal{U}(l)} \mathbf{h}_{k}^{\mathcal{B}(l-1)}
            \mathbf{W}^{\mathcal{U}(l)} \right)

    Here :math:`\mathcal{B}` and :math:`\mathcal{U}` index the balanced and
    unbalanced embeddings, :math:`\hat{\mathcal{N}}_{i}^{+}` are the positive
    neighbours of :math:`i` (plus :math:`i` itself when self-loops are
    enabled) and :math:`\mathcal{N}_{i}^{-}` are its negative neighbours.
    The attention coefficients are
    :math:`\alpha = \mathrm{softmax}_i(\tanh(\mathbf{a}^{\top}
    [\mathbf{h}_j \mathbf{W} \,\|\, \mathbf{h}_i \mathbf{W}]))`.
    The outer :math:`\tanh` of the equations above is **not** applied inside
    this layer; :meth:`forward` returns the raw aggregation.

    In case :obj:`first_aggr` is :obj:`False`, the layer expects :obj:`x` to
    be a tensor where :obj:`x[..., :in_dim]` holds the balanced node features
    :math:`\mathbf{H}^{\mathcal{B}}` and :obj:`x[..., in_dim:]` holds the
    unbalanced node features :math:`\mathbf{H}^{\mathcal{U}}` (both of width
    :obj:`in_dim`), e.g. the output of a preceding :class:`SNEAConv`.

    Args:
        in_dim (int): Size of each input sample (per half when
            :obj:`first_aggr` is :obj:`False`).
        out_dim (int): Size of each output half. The output has
            :obj:`2 * out_dim` features: balanced, then unbalanced.
        first_aggr (bool): If :obj:`True`, :obj:`x` holds raw node features;
            balanced embeddings aggregate over positive edges only and
            unbalanced embeddings over negative edges only. If :obj:`False`,
            both embeddings aggregate over positive and negative edges.
        bias (bool, optional): If set to :obj:`False`, the linear
            transformations will not learn an additive bias.
            (default: :obj:`True`)
        norm_emb (bool, optional): Stored on the module but not used by this
            layer. (default: :obj:`True`)
        add_self_loops (bool, optional): If set to :obj:`False`, no
            self-loops are added; existing self-loops are always removed.
            (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        first_aggr: bool,
        bias: bool = True,
        norm_emb: bool = True,
        add_self_loops: bool = True,
        **kwargs
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_dim = in_dim
        self.out_dim = out_dim
        self.first_aggr = first_aggr
        self.add_self_loops = add_self_loops
        self.norm_emb = norm_emb

        # Separate feature transforms for the balanced (B) and unbalanced (U)
        # embeddings.
        self.lin_b = torch.nn.Linear(in_dim, out_dim, bias)
        self.lin_u = torch.nn.Linear(in_dim, out_dim, bias)
        # Attention scorers over concatenated [source || target] features.
        self.alpha_u = torch.nn.Linear(self.out_dim * 2, 1)
        self.alpha_b = torch.nn.Linear(self.out_dim * 2, 1)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all learnable parameters of the layer."""
        self.lin_b.reset_parameters()
        self.lin_u.reset_parameters()
        # Reset the attention biases too, then use Xavier-normal weights.
        self.alpha_b.reset_parameters()
        self.alpha_u.reset_parameters()
        torch.nn.init.xavier_normal_(self.alpha_b.weight)
        torch.nn.init.xavier_normal_(self.alpha_u.weight)

    def _clean_edges(self, edge_index: Tensor, with_loops: bool,
                     num_nodes: int) -> Tensor:
        """Drop existing self-loops; optionally add one self-loop per node."""
        edge_index, _ = remove_self_loops(edge_index)
        if with_loops:
            # num_nodes ensures isolated trailing nodes also get a self-loop.
            edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        return edge_index

    def forward(self, x: Union[Tensor, PairTensor], pos_edge_index: LongTensor,
                neg_edge_index: LongTensor) -> Tensor:
        """Compute balanced and unbalanced node embeddings.

        Args:
            x (Tensor): Node features. Raw features of width :obj:`in_dim`
                when :obj:`first_aggr` is :obj:`True`, otherwise the
                concatenation ``[balanced || unbalanced]`` of width
                ``2 * in_dim``. Only a single tensor is supported.
            pos_edge_index (LongTensor): Positive edges, shape ``[2, E+]``.
            neg_edge_index (LongTensor): Negative edges, shape ``[2, E-]``.

        Returns:
            Tensor: ``[balanced || unbalanced]`` embeddings with
            ``2 * out_dim`` features.
        """
        num_nodes = x.size(self.node_dim)

        if self.first_aggr:
            h_b = self.lin_b(x)
            h_u = self.lin_u(x)

            # Balanced embedding: attention over positive neighbours only.
            edge = self._clean_edges(pos_edge_index, self.add_self_loops,
                                     num_nodes)
            out_b = self.propagate(edge, x1=h_b, x2=h_b,
                                   edge_p=edge.new_zeros(edge.size(-1)),
                                   alpha_func=self.alpha_b)

            # Unbalanced embedding: attention over negative neighbours only.
            edge = self._clean_edges(neg_edge_index, self.add_self_loops,
                                     num_nodes)
            out_u = self.propagate(edge, x1=h_u, x2=h_u,
                                   edge_p=edge.new_zeros(edge.size(-1)),
                                   alpha_func=self.alpha_u)
        else:
            F_in = self.in_dim
            h_b = x[..., :F_in]
            h_u = x[..., F_in:]

            # Shared edge set: positive edges (optionally self-looped) tagged
            # 0, followed by negative edges (never self-looped) tagged 1.
            pos = self._clean_edges(pos_edge_index, self.add_self_loops,
                                    num_nodes)
            neg = self._clean_edges(neg_edge_index, False, num_nodes)
            edge = torch.cat([pos, neg], dim=-1)
            edge_p = torch.cat([pos.new_zeros(pos.size(-1)),
                                neg.new_ones(neg.size(-1))], dim=-1)

            # Balanced: positive neighbours contribute balanced features,
            # negative neighbours contribute unbalanced features.
            out_b = self.propagate(edge, x1=self.lin_b(h_b),
                                   x2=self.lin_b(h_u), edge_p=edge_p,
                                   alpha_func=self.alpha_b)
            # Unbalanced: positive -> unbalanced, negative -> balanced.
            out_u = self.propagate(edge, x1=self.lin_u(h_u),
                                   x2=self.lin_u(h_b), edge_p=edge_p,
                                   alpha_func=self.alpha_u)

        return torch.cat([out_b, out_u], dim=-1)

    def message(self, x1_j: Tensor, x2_j: Tensor, x1_i: Tensor, x2_i: Tensor,
                edge_p: Tensor, alpha_func, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        """Attention-weighted message from source ``j`` to target ``i``.

        ``edge_p`` selects per edge which feature branch is used: 0 picks
        ``x1``, 1 picks ``x2``.
        """
        use_x2 = edge_p.unsqueeze(-1).bool()
        edge_h = torch.where(use_x2,
                             torch.cat([x2_j, x2_i], dim=-1),
                             torch.cat([x1_j, x1_i], dim=-1))
        alpha = torch.tanh(alpha_func(edge_h))
        # Normalise over all edges entering the same target node.
        alpha = softmax(alpha, index, ptr, size_i)
        # The message carries the *source* (neighbour) features.
        x_j = torch.where(use_x2, x2_j, x1_j)
        return x_j * alpha

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_dim}, '
                f'{self.out_dim}, first_aggr={self.first_aggr})')