# Source code for torch_geometric_signed_directed.nn.directed.DIMPA

import torch
from torch.nn.parameter import Parameter

from ..general.conv_base import Conv_Base


class DIMPA(torch.nn.Module):
    r"""The directed mixed-path aggregation model from the
    `DIGRAC: Digraph Clustering Based on Flow Imbalance
    <https://proceedings.mlr.press/v198/he22b.html>`_ paper.

    Two feature streams (source and target) are propagated for ``hop``
    steps, the source stream over ``edge_index`` and the target stream over
    the reversed edges. The per-hop results are combined with learnable
    scalar weights, and the two streams are concatenated.

    Args:
        hop (int): Number of hops to consider. Must be non-negative; with
            ``hop=0`` only the weighted input features are returned.
        fill_value (float, optional): The layer computes
            :math:`\mathbf{\hat{A}}` as
            :math:`\mathbf{A} + \text{fill\_value} \cdot \mathbf{I}`.
            (default: :obj:`0.5`)

    Raises:
        ValueError: If ``hop`` is negative.
    """

    def __init__(self, hop: int, fill_value: float = 0.5):
        super().__init__()
        # Fail fast: a negative hop would otherwise either fail on tensor
        # allocation or only surface as an IndexError inside forward().
        if hop < 0:
            raise ValueError(f"hop must be non-negative, got {hop}")
        self._hop = hop
        # One scalar weight per hop 0..hop for each stream. The storage is
        # allocated uninitialized and immediately filled by
        # _reset_parameters(); dtype is pinned to float32 as before.
        self._w_s = Parameter(torch.empty(hop + 1, 1, dtype=torch.float))
        self._w_t = Parameter(torch.empty(hop + 1, 1, dtype=torch.float))
        self.conv_layer = Conv_Base(fill_value)
        self._reset_parameters()

    def _reset_parameters(self):
        """Set every per-hop weight of both streams to 1.0."""
        self._w_s.data.fill_(1.0)
        self._w_t.data.fill_(1.0)

    def forward(
        self,
        x_s: torch.FloatTensor,
        x_t: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Making a forward pass of DIMPA.

        Arg types:
            * **x_s** (PyTorch FloatTensor) - Source hidden representations.
            * **x_t** (PyTorch FloatTensor) - Target hidden representations.
            * **edge_index** (PyTorch LongTensor) - Edge indices, two rows;
              the target stream uses the rows swapped (reversed edges).
            * **edge_weight** (PyTorch FloatTensor) - Edge weights, shared by
              both directions.

        Return types:
            * **feat** (PyTorch FloatTensor) - Embedding matrix, the source
              and target results concatenated along dim 1, i.e. shape
              (num_nodes, 2*input_dim) when both inputs share input_dim.
        """
        # Hop-0 contribution: the raw features scaled by their weights.
        feat_s = self._w_s[0] * x_s
        feat_t = self._w_t[0] * x_t
        curr_s = x_s.clone()
        curr_t = x_t.clone()
        # Swapping the two rows reverses every edge, so the target stream
        # runs the same convolution on the transposed graph.
        edge_index_t = edge_index[[1, 0]]
        for h in range(1, self._hop + 1):
            curr_s = self.conv_layer(curr_s, edge_index, edge_weight)
            curr_t = self.conv_layer(curr_t, edge_index_t, edge_weight)
            # Out-of-place accumulation keeps the autograd graph free of
            # in-place modifications.
            feat_s = feat_s + self._w_s[h] * curr_s
            feat_t = feat_t + self._w_t[h] * curr_t
        return torch.cat([feat_s, feat_t], dim=1)  # concatenate results