torch_geometric_signed_directed.nn.directed.DIGRAC_node_clustering

Classes

DIGRAC_node_clustering

The directed graph clustering model from the DIGRAC: Digraph Clustering Based on Flow Imbalance paper.

Module Contents

class DIGRAC_node_clustering(num_features: int, hidden: int, nclass: int, fill_value: float, dropout: float, hop: int)

Bases: torch.nn.Module

The directed graph clustering model from the DIGRAC: Digraph Clustering Based on Flow Imbalance paper.

Parameters:
  • num_features (int) – Number of input features per node.

  • hidden (int) – Hidden dimension of the initial MLP.

  • nclass (int) – Number of clusters.

  • fill_value (float) – Edge weight assigned to the added self-loops.

  • dropout (float) – Dropout probability.

  • hop (int) – Number of hops to consider.
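To make the role of fill_value concrete, the plain-Python sketch below appends one self-loop (i, i) with weight fill_value to every node of a graph stored in COO form. The helper add_self_loops_coo is hypothetical and only illustrates the idea; the model's internal routine is not shown here and may differ, for instance by adding loops only to nodes that do not already have one.

```python
from typing import List, Tuple


def add_self_loops_coo(
    edge_index: Tuple[List[int], List[int]],
    edge_weight: List[float],
    num_nodes: int,
    fill_value: float,
) -> Tuple[Tuple[List[int], List[int]], List[float]]:
    """Append a self-loop (i, i) with weight fill_value for every node.

    Hypothetical helper for illustration only; it mirrors the idea behind
    the fill_value parameter, not the library's internal code.
    """
    # Copy so the caller's lists are left untouched.
    src, dst = list(edge_index[0]), list(edge_index[1])
    weights = list(edge_weight)
    for node in range(num_nodes):
        src.append(node)       # loop starts at the node ...
        dst.append(node)       # ... and ends at the same node
        weights.append(fill_value)
    return (src, dst), weights
```

For a 3-node graph with edges 0→1 and 1→2 (weight 1.0) and fill_value=0.5, the result has source row [0, 1, 0, 1, 2], target row [1, 2, 0, 1, 2] and weights [1.0, 1.0, 0.5, 0.5, 0.5].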

dropout
forward(edge_index: torch.FloatTensor, edge_weight: torch.FloatTensor, features: torch.FloatTensor) → Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor, torch.FloatTensor]

Makes a forward pass of the DIGRAC node clustering model.

Arg types:
  • edge_index (PyTorch LongTensor) - Edge indices, with shape (2, num_edges).

  • edge_weight (PyTorch FloatTensor) - Edge weights, with shape (num_edges,).

  • features (PyTorch FloatTensor) - Input node features, with shape (num_nodes, num_features).
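The edge inputs follow the usual COO layout: edge_index stacks source nodes and target nodes as two rows, and edge_weight holds one weight per edge in the same order. The sketch below, built around a hypothetical helper edges_to_coo, assembles both from (source, target, weight) triples in plain Python.

```python
from typing import Iterable, List, Tuple


def edges_to_coo(
    edges: Iterable[Tuple[int, int, float]],
) -> Tuple[List[List[int]], List[float]]:
    """Split (source, target, weight) triples into COO indices and weights.

    Hypothetical helper for illustration; not part of the library.
    """
    sources: List[int] = []
    targets: List[int] = []
    weights: List[float] = []
    for src, dst, w in edges:
        sources.append(src)
        targets.append(dst)
        weights.append(w)
    # Row 0 holds source nodes, row 1 holds target nodes: shape (2, num_edges).
    return [sources, targets], weights
```

Before calling forward, the lists would be turned into tensors, for example with torch.tensor(edge_index, dtype=torch.long) and torch.tensor(edge_weight, dtype=torch.float).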

Return types:
  • z (PyTorch FloatTensor) - Embedding matrix, with shape (num_nodes, 2*hidden).

  • output (PyTorch FloatTensor) - Logarithm of prob (log-probabilities of the cluster assignments), with shape (num_nodes, num_clusters).

  • predictions_cluster (PyTorch LongTensor) - Predicted cluster labels, with shape (num_nodes,).

  • prob (PyTorch FloatTensor) - Probability assignment matrix of different clusters, with shape (num_nodes, num_clusters).
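The last three return values are tied together: output is the elementwise log of prob, and it is assumed here (not stated above) that a node's predicted label is the cluster with the highest probability in its row. The plain-Python sketch below, using a hypothetical helper summarize_cluster_outputs, derives both from a given prob matrix. It assumes every entry of prob is strictly positive, as softmax outputs are, since math.log(0) raises ValueError.

```python
import math
from typing import List, Tuple


def summarize_cluster_outputs(
    prob: List[List[float]],
) -> Tuple[List[List[float]], List[int]]:
    """Return (log_prob, labels) for a row-stochastic assignment matrix.

    Hypothetical helper; the argmax labelling rule is an assumption.
    """
    # Elementwise log, matching the description of `output`.
    log_prob = [[math.log(p) for p in row] for row in prob]
    # Index of the largest probability per node (first one wins on ties).
    labels = [max(range(len(row)), key=row.__getitem__) for row in prob]
    return log_prob, labels
```

With prob = [[0.7, 0.2, 0.1], [0.25, 0.25, 0.5]], the labels come out as [0, 2], and exponentiating each row of log_prob recovers probabilities that sum to 1.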