torch_geometric_signed_directed.utils.general.link_split
Functions
- undirected_label2directed_label: Generate edge labels based on the task.
- link_class_split: Get train/val/test datasets for the link prediction task.
Module Contents
- undirected_label2directed_label(adj: scipy.sparse.csr_matrix, edge_pairs: List[Tuple], task: str, directed_graph: bool = True, signed_directed: bool = False) → List | List
Generate edge labels based on the task.
- Arg types:
  - adj (scipy.sparse.csr_matrix) - Scipy sparse undirected adjacency matrix.
  - edge_pairs (List[Tuple]) - The edge list used to query the link dataset. Each element in the list is an edge tuple.
  - edge_weight (List[Tuple]) - The edge weight list for signed graphs.
  - task (str) - The task: three_class_digraph (three-class link prediction); direction (direction prediction); existence (existence prediction); sign (sign prediction); four_class_signed_digraph (directed sign prediction); five_class_signed_digraph (directed sign and existence prediction).
- Return types:
  - new_edge_pairs (List) - A list of edges.
  - labels (List) - The labels for new_edge_pairs. For every task except "sign", the undirected edges in the directed input graph are removed to avoid ambiguity.
    - If task == "existence": 0 (the directed edge exists in the graph), 1 (the edge doesn't exist).
    - If task == "direction": 0 (the directed edge exists in the graph), 1 (the edge of the reversed direction exists).
    - If task == "three_class_digraph": 0 (the directed edge exists in the graph), 1 (the edge of the reversed direction exists), 2 (the edge exists in neither direction).
    - If task == "four_class_signed_digraph": 0 (the positive directed edge exists in the graph), 1 (the negative directed edge exists in the graph), 2 (the positive edge of the reversed direction exists), 3 (the negative edge of the reversed direction exists).
    - If task == "five_class_signed_digraph": 0 (the positive directed edge exists in the graph), 1 (the negative directed edge exists in the graph), 2 (the positive edge of the reversed direction exists), 3 (the negative edge of the reversed direction exists), 4 (the edge exists in neither direction).
    - If task == "sign": 0 (negative edge), 1 (positive edge).
  - label_weight (List) - The weight list of the query edges. The weight is zero if the edge exists in neither direction.
  - undirected (List) - The list of undirected edges within the input graph.
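To make the unsigned label scheme concrete, here is a minimal sketch in plain Python. It does not call the library: `label_three_class` is a hypothetical helper, and representing the graph as a set of directed `(source, target)` tuples is an assumption made for illustration only. It follows the "three_class_digraph" convention described above, including the removal of pairs that are connected in both directions.

```python
from typing import List, Set, Tuple

Edge = Tuple[int, int]


def label_three_class(directed_edges: Set[Edge],
                      queries: List[Edge]) -> Tuple[List[Edge], List[int]]:
    """Hypothetical helper: label query pairs as in "three_class_digraph"."""
    kept: List[Edge] = []
    labels: List[int] = []
    for u, v in queries:
        forward = (u, v) in directed_edges
        backward = (v, u) in directed_edges
        if forward and backward:
            # Edge present in both directions (undirected): dropped as ambiguous.
            continue
        if forward:
            label = 0   # the directed edge exists
        elif backward:
            label = 1   # only the reversed edge exists
        else:
            label = 2   # no edge in either direction
        kept.append((u, v))
        labels.append(label)
    return kept, labels


# Toy graph: 1 <-> 2 is bidirectional, so the query (1, 2) is removed.
edges = {(0, 1), (1, 2), (2, 1), (3, 0)}
queries = [(0, 1), (0, 3), (1, 2), (2, 3)]
kept, labels = label_three_class(edges, queries)
print(kept)    # [(0, 1), (0, 3), (2, 3)]
print(labels)  # [0, 1, 2]
```

Under the same assumptions, the "direction" task would keep only the pairs labeled 0 or 1, and the "existence" task would relabel according to its own two-class convention.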
- link_class_split(data: torch_geometric.data.Data, size: int = None, splits: int = 2, prob_test: float = 0.15, prob_val: float = 0.05, task: str = 'direction', seed: int = 0, maintain_connect: bool = True, ratio: float = 1.0, device: str = 'cpu') → dict
Get train/val/test datasets for the link prediction task.
- Arg types:
  - data (torch_geometric.data.Data or DirectedData object) - The input dataset.
  - prob_val (float, optional) - The proportion of edges selected for validation (Default: 0.05).
  - prob_test (float, optional) - The proportion of edges selected for testing (Default: 0.15).
  - splits (int, optional) - The number of splits to generate (Default: 2).
  - size (int, optional) - The size of the input graph. If None, the graph size is the maximum node index plus 1 (Default: None).
  - task (str, optional) - The evaluation task: three_class_digraph (three-class link prediction); direction (direction prediction); existence (existence prediction); sign (sign prediction); four_class_signed_digraph (directed sign prediction); five_class_signed_digraph (directed sign and existence prediction) (Default: 'direction').
  - seed (int, optional) - The random seed for positive edge selection (Default: 0). Negative edges are selected by PyTorch Geometric negative_sampling.
  - maintain_connect (bool, optional) - Whether to maintain connectivity when removing edges for validation and testing. Connectivity is maintained by first obtaining the edges of the minimum spanning tree/forest; these edges are never removed for validation or testing (Default: True).
  - ratio (float, optional) - The maximum ratio of edges used for dataset generation (Default: 1.0).
  - device (str, optional) - The device to hold the return value (Default: 'cpu').
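The idea behind maintain_connect can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the library's implementation: `split_keeping_forest` is a hypothetical name, edges are treated as unweighted and undirected for the connectivity check, and a union-find pass stands in for the minimum spanning tree/forest step (with unit weights, any spanning forest is a minimum one). Edges of the spanning forest are protected, and validation/test edges are sampled only from the remainder.

```python
import random
from typing import List, Tuple

Edge = Tuple[int, int]


def split_keeping_forest(edges: List[Edge], num_nodes: int,
                         prob_val: float, prob_test: float,
                         seed: int = 0) -> Tuple[List[Edge], List[Edge], List[Edge]]:
    """Hypothetical sketch: hold out val/test edges without disconnecting the graph."""
    parent = list(range(num_nodes))

    def find(x: int) -> int:
        # Union-find root lookup with path halving.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    protected: List[Edge] = []   # spanning-forest edges, never removed
    removable: List[Edge] = []   # candidates for validation/testing
    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            protected.append((u, v))
        else:
            removable.append((u, v))

    rng = random.Random(seed)    # seeded, so the split is reproducible
    rng.shuffle(removable)
    n_val = min(int(prob_val * len(edges)), len(removable))
    n_test = min(int(prob_test * len(edges)), len(removable) - n_val)
    val = removable[:n_val]
    test = removable[n_val:n_val + n_test]
    train = protected + removable[n_val + n_test:]
    return train, val, test


# A 6-node ring plus chords: 12 edges in total.
ring = [(i, (i + 1) % 6) for i in range(6)]
chords = [(0, 2), (1, 3), (2, 4), (3, 5), (0, 3), (1, 4)]
train, val, test = split_keeping_forest(ring + chords, 6, 0.25, 0.25, seed=0)
print(len(train), len(val), len(test))
```

Because the held-out edges never come from the protected forest, the observed training graph in this sketch stays connected whenever the input graph is connected.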
- Return types:
  - datasets - A dict containing the training/validation/testing splits of edges and labels. For split index i:
    - datasets[i]['graph'] (torch.LongTensor): the observed edge list after removing edges for validation and testing.
    - datasets[i]['train'/'val'/'testing']['edges'] (List): the edge list for training/validation/testing.
    - datasets[i]['train'/'val'/'testing']['label'] (List): the labels of the edges. For every task except "sign", the undirected edges in the directed input graph are removed to avoid ambiguity.
      - If task == "existence": 0 (the directed edge exists in the graph), 1 (the edge doesn't exist).
      - If task == "direction": 0 (the directed edge exists in the graph), 1 (the edge of the reversed direction exists).
      - If task == "three_class_digraph": 0 (the directed edge exists in the graph), 1 (the edge of the reversed direction exists), 2 (the edge exists in neither direction).
      - If task == "four_class_signed_digraph": 0 (the positive directed edge exists in the graph), 1 (the negative directed edge exists in the graph), 2 (the positive edge of the reversed direction exists), 3 (the negative edge of the reversed direction exists).
      - If task == "five_class_signed_digraph": 0 (the positive directed edge exists in the graph), 1 (the negative directed edge exists in the graph), 2 (the positive edge of the reversed direction exists), 3 (the negative edge of the reversed direction exists), 4 (the edge exists in neither direction).
      - If task == "sign": 0 (negative edge), 1 (positive edge). This is the link sign prediction task for signed networks.
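The signed label classes can be illustrated with a small plain-Python sketch. It does not use the library: `label_signed` is a hypothetical helper, the graph is assumed to be a dict mapping directed `(source, target)` pairs to nonzero signed weights, and class 3 is read as "the negative edge of the reversed direction exists", which is an interpretation of the class list above rather than something confirmed against the implementation.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def label_signed(weights: Dict[Edge, float], queries: List[Edge],
                 five_class: bool = True) -> Tuple[List[Edge], List[int]]:
    """Hypothetical helper: four-/five-class labels for signed directed pairs."""
    kept: List[Edge] = []
    labels: List[int] = []
    for u, v in queries:
        w_fwd = weights.get((u, v), 0.0)
        w_bwd = weights.get((v, u), 0.0)
        if w_fwd != 0 and w_bwd != 0:
            # Edge present in both directions: dropped as ambiguous.
            continue
        if w_fwd > 0:
            label = 0   # positive directed edge
        elif w_fwd < 0:
            label = 1   # negative directed edge
        elif w_bwd > 0:
            label = 2   # positive reversed edge
        elif w_bwd < 0:
            label = 3   # negative reversed edge (assumed reading of class 3)
        elif five_class:
            label = 4   # no edge in either direction
        else:
            continue    # the four-class task has no "no edge" class
        kept.append((u, v))
        labels.append(label)
    return kept, labels


weights = {(0, 1): 1.0, (1, 2): -1.0, (3, 2): 2.0,
           (4, 3): -0.5, (0, 4): 1.0, (4, 0): -1.0}
queries = [(0, 1), (1, 2), (2, 3), (3, 4), (0, 4), (1, 3)]
kept5, labels5 = label_signed(weights, queries, five_class=True)
kept4, labels4 = label_signed(weights, queries, five_class=False)
print(labels5)  # [0, 1, 2, 3, 4]
print(labels4)  # [0, 1, 2, 3]
```

The query (0, 4) is removed in both cases because the toy graph has edges in both directions between nodes 0 and 4; the query (1, 3) only receives a label in the five-class variant.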