Source code for torch_geometric_signed_directed.nn.directed.DiGCN_Inception_Block

from typing import Tuple

import torch
from torch.nn import Linear

from .DiGCNConv import DiGCNConv


class DiGCN_InceptionBlock(torch.nn.Module):
    r"""An implementation of the inception block model from the
    `Digraph Inception Convolutional Networks
    <https://papers.nips.cc/paper/2020/file/cffb6e2288a630c2a787a64ccc67097c-Paper.pdf>`_
    paper.

    The block holds three parallel branches that all read the same node
    features: a plain linear projection (no graph structure) and two
    :class:`DiGCNConv` layers, each driven by its own edge index / edge
    weight pair. The branch outputs are returned separately; combining them
    is left to the caller.

    Args:
        in_dim (int): Dimension of input.
        out_dim (int): Dimension of output.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Branch 0: feature-only projection that ignores the graph.
        self.ln = Linear(in_dim, out_dim)
        # Branches 1 and 2: graph convolutions, one per graph passed to
        # forward(). NOTE(review): in the paper these presumably correspond
        # to different proximity orders -- confirm with the calling model.
        self.conv1 = DiGCNConv(in_dim, out_dim)
        self.conv2 = DiGCNConv(in_dim, out_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the parameters of all three branches."""
        self.ln.reset_parameters()
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(
        self,
        x: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor,
        edge_index2: torch.LongTensor,
        edge_weight2: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """
        Making a forward pass of the DiGCN inception block model.

        Arg types:
            * x (PyTorch FloatTensor) - Node features; the last dimension is
              expected to be ``in_dim`` (it is fed to ``Linear(in_dim, out_dim)``).
            * edge_index, edge_index2 (PyTorch LongTensor) - Edge indices of
              the graphs used by the first and second convolution branch.
            * edge_weight, edge_weight2 (PyTorch FloatTensor) - Edge weights
              corresponding to ``edge_index`` and ``edge_index2`` respectively.

        Return types:
            * x0, x1, x2 (PyTorch FloatTensor) - Hidden representations from
              the linear branch, the first convolution branch and the second
              convolution branch, in that order.
        """
        # All three branches consume the same input x independently.
        x0 = self.ln(x)
        x1 = self.conv1(x, edge_index, edge_weight)
        x2 = self.conv2(x, edge_index2, edge_weight2)
        return x0, x1, x2