torch_geometric_signed_directed.nn.signed.SSSNET_link_prediction
Classes
SSSNET_link_prediction: The signed graph link prediction model adapted from the SSSNET: Semi-Supervised Signed Network Clustering paper.
Module Contents
- class SSSNET_link_prediction(nfeat: int, hidden: int, nclass: int, dropout: float, hop: int, fill_value: float, directed: bool = False, bias: bool = True)
Bases: torch.nn.Module

The signed graph link prediction model adapted from the SSSNET: Semi-Supervised Signed Network Clustering paper.
- Parameters:
nfeat (int) – Number of input node features.
hidden (int) – Hidden dimensions of the initial MLP.
nclass (int) – Number of link classes.
dropout (float) – Dropout probability.
hop (int) – Number of hops to consider.
fill_value (float) – Value for added self-loops for the positive part of the adjacency matrix.
directed (bool, optional) – Whether the input network is directed or not. (default: False)
bias (bool, optional) – If set to False, the layer will not learn an additive bias. (default: True)
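The role of `fill_value` can be pictured with a dense matrix. The sketch below is illustrative only: `positive_adjacency_with_self_loops` is a hypothetical helper, not part of torch_geometric_signed_directed, and it assumes every node receives a new self-loop of weight `fill_value` on top of the positive adjacency, i.e. A_p + fill_value * I. The library itself may use sparse tensors, skip nodes that already have a self-loop, or normalize afterwards.

```python
import torch


def positive_adjacency_with_self_loops(edge_index_p: torch.Tensor,
                                        edge_weight_p: torch.Tensor,
                                        num_nodes: int,
                                        fill_value: float) -> torch.Tensor:
    """Hypothetical helper: dense A_p + fill_value * I.

    Assumption: a self-loop of weight `fill_value` is added to every node;
    the real model may treat existing self-loops differently.
    """
    adj = torch.zeros(num_nodes, num_nodes, dtype=edge_weight_p.dtype)
    # Scatter each positive edge weight into A_p[src, dst]; duplicate edges are summed.
    adj.index_put_((edge_index_p[0], edge_index_p[1]), edge_weight_p, accumulate=True)
    # Add a self-loop of weight `fill_value` on the diagonal.
    adj += fill_value * torch.eye(num_nodes, dtype=edge_weight_p.dtype)
    return adj


edge_index_p = torch.tensor([[0, 1], [1, 2]])  # edges 0->1 and 1->2
edge_weight_p = torch.tensor([1.0, 2.0])
A = positive_adjacency_with_self_loops(edge_index_p, edge_weight_p, num_nodes=3, fill_value=0.5)
```

A larger `fill_value` gives each node's own representation more weight relative to its positive neighbours when the adjacency is later used for propagation.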
- forward(edge_index_p: torch.LongTensor, edge_weight_p: torch.FloatTensor, edge_index_n: torch.LongTensor, edge_weight_n: torch.FloatTensor, features: torch.FloatTensor, query_edges: torch.LongTensor) → Tuple[torch.FloatTensor, torch.FloatTensor, torch.LongTensor, torch.FloatTensor]
Makes a forward pass of the SSSNET link prediction model.
- Arg types:
edge_index_p, edge_index_n (PyTorch LongTensor) - Edge indices for the positive and negative parts.
edge_weight_p, edge_weight_n (PyTorch FloatTensor) - Edge weights for the positive and negative parts.
features (PyTorch FloatTensor) - Input node features, with shape (num_nodes, num_features).
query_edges (PyTorch LongTensor) - Edge indices of the links whose labels are queried.
- Return types:
log_prob (PyTorch FloatTensor) - Logarithmic class probabilities for the query edges, with shape (num_query_edges, num_classes).
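`forward` expects the positive and negative parts of the signed graph as separate edge lists. A minimal sketch of preparing them from a single signed edge list, and of reading the returned log-probabilities, follows. `split_signed_edges` and `decode_predictions` are hypothetical helpers, not library functions; the sketch assumes negative edges are passed with their absolute weights and that each row of `log_prob` is one prediction. The `log_prob` below is a stand-in built with `torch.log_softmax`, not real model output.

```python
from typing import Tuple

import torch


def split_signed_edges(edge_index: torch.Tensor,
                       edge_weight: torch.Tensor
                       ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical helper: split signed edges into positive and negative parts.

    Assumption: negative edges are given to the model with absolute weights.
    Zero-weight edges are dropped from both parts.
    """
    pos = edge_weight > 0
    neg = edge_weight < 0
    edge_index_p, edge_weight_p = edge_index[:, pos], edge_weight[pos]
    edge_index_n, edge_weight_n = edge_index[:, neg], edge_weight[neg].abs()
    return edge_index_p, edge_weight_p, edge_index_n, edge_weight_n


def decode_predictions(log_prob: torch.Tensor) -> torch.Tensor:
    """Turn log class probabilities (one row per prediction) into hard labels."""
    return log_prob.argmax(dim=1)


# A toy signed graph: 0->1 positive, 1->2 negative, 2->0 positive.
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
edge_weight = torch.tensor([1.0, -2.0, 0.5])
edge_index_p, edge_weight_p, edge_index_n, edge_weight_n = split_signed_edges(edge_index, edge_weight)

# Stand-in for model output: two predictions over two link classes.
log_prob = torch.log_softmax(torch.tensor([[2.0, 0.1], [0.0, 3.0]]), dim=1)
labels = decode_predictions(log_prob)
probs = log_prob.exp()  # each row sums to 1
```

Exponentiating `log_prob` recovers ordinary probabilities, while `argmax` over the class dimension gives the predicted link class without needing the exponentiation.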