torch_geometric_signed_directed.nn.directed.DIMPA
Classes
- DIMPA: The directed mixed-path aggregation model from the DIGRAC paper.
Module Contents
- class DIMPA(hop: int, fill_value: float = 0.5)
Bases: torch.nn.Module
The directed mixed-path aggregation model from the "DIGRAC: Digraph Clustering Based on Flow Imbalance" paper.
- Parameters:
hop (int) – Number of hops to consider.
fill_value (float, optional) – The layer computes \(\mathbf{\hat{A}} = \mathbf{A} + \text{fill\_value} \cdot \mathbf{I}\), i.e. it adds self-loops of weight fill_value. (default: 0.5)
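To illustrate the fill_value formula, the sketch below builds \(\mathbf{\hat{A}} = \mathbf{A} + \text{fill\_value} \cdot \mathbf{I}\) as a dense matrix from a weighted edge list in plain Python. The helper name `self_looped_adjacency` and the `(src, dst, weight)` edge-list format are hypothetical, chosen only for readability. The library itself receives `edge_index`/`edge_weight` tensors, and how it forms \(\mathbf{\hat{A}}\) internally (for example, whether it normalizes the result) is not specified in this reference.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def self_looped_adjacency(num_nodes: int,
                          edges: List[Tuple[int, int, float]],
                          fill_value: float = 0.5) -> Matrix:
    """Dense A_hat = A + fill_value * I from a weighted edge list.

    Hypothetical illustration helper; not part of the library's API.
    """
    a_hat = [[0.0] * num_nodes for _ in range(num_nodes)]
    # A[src][dst] accumulates the edge weight (duplicate edges are summed).
    for src, dst, w in edges:
        a_hat[src][dst] += w
    # Adding fill_value * I puts fill_value on every diagonal entry.
    for i in range(num_nodes):
        a_hat[i][i] += fill_value
    return a_hat


a_hat = self_looped_adjacency(3, [(0, 1, 1.0), (1, 2, 2.0)], fill_value=0.5)
print(a_hat)  # → [[0.5, 1.0, 0.0], [0.0, 0.5, 2.0], [0.0, 0.0, 0.5]]
```

The dense form only makes the formula visible entry by entry; for real graphs a sparse edge-list representation avoids the O(num_nodes²) memory cost.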
- conv_layer
- forward(x_s: torch.FloatTensor, x_t: torch.FloatTensor, edge_index: torch.FloatTensor, edge_weight: torch.FloatTensor) → torch.FloatTensor
Makes a forward pass of DIMPA.
- Arg types:
x_s (PyTorch FloatTensor) - Source hidden representations.
x_t (PyTorch FloatTensor) - Target hidden representations.
edge_index (PyTorch FloatTensor) - Edge indices.
edge_weight (PyTorch FloatTensor) - Edge weights.
- Return types:
feat (PyTorch FloatTensor) - Embedding matrix of shape (num_nodes, 2*input_dim), where input_dim is the feature dimension of x_s and x_t.
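One way to picture a forward pass that produces the documented (num_nodes, 2*input_dim) output is the plain-Python sketch below. It assumes a mixed-path aggregation in which each branch is a weighted sum over hops, \(\sum_{h=0}^{\text{hop}} w_s[h]\,\mathbf{\hat{A}}^h X_s\) for the source branch and the same with \(\mathbf{\hat{A}}^\top\) and \(w_t\) for the target branch, followed by concatenation along the feature dimension. The function names, dense list-of-lists matrices, and fixed per-hop weights are hypothetical simplifications. This reference does not specify the library's normalization or which edge direction each branch follows, so treat the `A_hat` / `A_hat^T` choice as an assumption.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain dense matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(row) for row in zip(*a)]


def mixed_path_aggregate(a_hat: Matrix, x_s: Matrix, x_t: Matrix,
                         w_s: List[float], w_t: List[float]) -> Matrix:
    """Hypothetical sketch of hop-weighted source/target aggregation.

    hop = len(w_s) - 1; w_s[0] and w_t[0] weight the un-propagated inputs.
    """
    assert len(w_s) == len(w_t), "both branches use the same number of hops"
    a_t = transpose(a_hat)
    feat_s = [[w_s[0] * v for v in row] for row in x_s]
    feat_t = [[w_t[0] * v for v in row] for row in x_t]
    curr_s, curr_t = x_s, x_t
    for h in range(1, len(w_s)):
        # Propagate one more hop: source branch by A_hat, target branch by
        # A_hat^T (assumed convention), then add the weighted hop-h term.
        curr_s = matmul(a_hat, curr_s)
        curr_t = matmul(a_t, curr_t)
        feat_s = [[f + w_s[h] * c for f, c in zip(fr, cr)]
                  for fr, cr in zip(feat_s, curr_s)]
        feat_t = [[f + w_t[h] * c for f, c in zip(fr, cr)]
                  for fr, cr in zip(feat_t, curr_t)]
    # Concatenate per node: shape (num_nodes, 2 * input_dim).
    return [rs + rt for rs, rt in zip(feat_s, feat_t)]


# Two nodes, one edge 0 -> 1 of weight 1, fill_value 0.5, hop = 1, unit weights.
out = mixed_path_aggregate([[0.5, 1.0], [0.0, 0.5]],
                           [[1.0], [2.0]], [[1.0], [2.0]],
                           [1.0, 1.0], [1.0, 1.0])
print(out)  # → [[3.5, 1.5], [3.0, 4.0]]
```

With input_dim = 1 each node ends up with 2 features: the first column comes from the source branch and the second from the target branch, matching the (num_nodes, 2*input_dim) return shape.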