Source code for torch_geometric_signed_directed.nn.directed.DIGRAC_node_clustering

from typing import Tuple

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from .DIMPA import DIMPA


class DIGRAC_node_clustering(torch.nn.Module):
    r"""The directed graph clustering model from the
    `DIGRAC: Digraph Clustering Based on Flow Imbalance
    <https://proceedings.mlr.press/v198/he22b.html>`_ paper.

    The model applies two independent two-layer MLPs (a "source" and a
    "target" branch) to the node features, passes both results through a
    :class:`DIMPA` layer, and maps the resulting embedding to cluster logits
    with a linear layer.

    Args:
        num_features (int): Number of features.
        hidden (int): Hidden dimensions of the initial MLP.
        nclass (int): Number of clusters.
        fill_value (float): Value for added self-loops.
        dropout (float): Dropout probability.
        hop (int): Number of hops to consider.
    """

    # Gain passed to Xavier-uniform initialisation of every weight matrix.
    _XAVIER_GAIN = 1.414

    def __init__(self, num_features: int, hidden: int, nclass: int,
                 fill_value: float, dropout: float, hop: int):
        super().__init__()
        self._num_clusters = int(nclass)

        # Weights are allocated uninitialised here and filled in by
        # ``_reset_parameters``. The registration order is kept stable so
        # that ``parameters()`` / ``state_dict()`` ordering does not change.
        self._w_s0 = Parameter(torch.empty(num_features, hidden))
        self._w_s1 = Parameter(torch.empty(hidden, hidden))
        self._w_t0 = Parameter(torch.empty(num_features, hidden))
        self._w_t1 = Parameter(torch.empty(hidden, hidden))

        self._dimpa = DIMPA(hop, fill_value)
        self._relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(p=dropout)

        self._bias = Parameter(torch.empty(self._num_clusters))
        # The classifier consumes an embedding of width ``2 * hidden``.
        self._W_prob = Parameter(torch.empty(2 * hidden, self._num_clusters))

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise all weight matrices and zero the bias."""
        for weight in (self._w_s0, self._w_s1, self._w_t0, self._w_t1):
            torch.nn.init.xavier_uniform_(weight, gain=self._XAVIER_GAIN)
        self._bias.data.fill_(0.0)
        torch.nn.init.xavier_uniform_(self._W_prob, gain=self._XAVIER_GAIN)

    def _mlp(self, features: torch.Tensor, w0: torch.Tensor,
             w1: torch.Tensor) -> torch.Tensor:
        """Apply one branch of the initial MLP: linear, ReLU, dropout, linear."""
        hidden = torch.mm(features, w0)
        hidden = self._relu(hidden)
        hidden = self.dropout(hidden)
        return torch.mm(hidden, w1)

    def forward(self, edge_index: torch.FloatTensor,
                edge_weight: torch.FloatTensor,
                features: torch.FloatTensor) -> Tuple[torch.FloatTensor,
                                                      torch.FloatTensor,
                                                      torch.LongTensor,
                                                      torch.FloatTensor]:
        """
        Making a forward pass of the DIGRAC node clustering model.

        Arg types:
            * **edge_index** (PyTorch FloatTensor) - Edge indices.
            * **edge_weight** (PyTorch FloatTensor) - Edge weights.
            * **features** (PyTorch FloatTensor) - Input node features, with shape (num_nodes, num_features).

        Return types:
            * **z** (PyTorch FloatTensor) - Embedding matrix after ``F.normalize``, with shape (num_nodes, 2*hidden).
            * **output** (PyTorch FloatTensor) - Log of prob, with shape (num_nodes, num_clusters).
            * **predictions_cluster** (PyTorch LongTensor) - Predicted labels.
            * **prob** (PyTorch FloatTensor) - Probability assignment matrix of different clusters, with shape (num_nodes, num_clusters).
        """
        # NOTE(review): edge_index is annotated/documented as a FloatTensor;
        # it is forwarded untouched to DIMPA, which presumably expects integer
        # node indices -- confirm against DIMPA before tightening the type.

        # MLP: the source branch runs before the target branch so that
        # dropout consumes random numbers in the same order as before.
        x_s = self._mlp(features, self._w_s0, self._w_s1)
        x_t = self._mlp(features, self._w_t0, self._w_t1)

        z = self._dimpa(x_s, x_t, edge_index, edge_weight)

        # The bias balances the difference in cluster probabilities.
        logits = torch.mm(z, self._W_prob) + self._bias

        predictions_cluster = torch.argmax(logits, dim=1)
        prob = F.softmax(logits, dim=1)
        output = F.log_softmax(logits, dim=1)

        return F.normalize(z), output, predictions_cluster, prob