
torch_geometric_signed_directed.utils.general.link_split

Functions

undirected_label2directed_label(→ Union[List, List])

Generate edge labels based on the task.

link_class_split(→ dict)

Get train/val/test dataset for the link prediction task.

Module Contents

undirected_label2directed_label(adj: scipy.sparse.csr_matrix, edge_pairs: List[Tuple], task: str, directed_graph: bool = True, signed_directed: bool = False) → List | List

Generate edge labels based on the task.

Arg types:

  • adj (scipy.sparse.csr_matrix) - Scipy sparse undirected adjacency matrix.

  • edge_pairs (List[Tuple]) - The edge list used to query the link dataset. Each element in the list is an edge tuple.

  • edge_weight (List[Tuple]) - The edge weight list for signed graphs.

  • task (str) - The task: three_class_digraph (three-class link prediction); direction (direction prediction); existence (existence prediction); sign (sign prediction); four_class_signed_digraph (directed sign prediction); five_class_signed_digraph (directed sign and existence prediction).

Return types:
  • new_edge_pairs (List) - A list of edges.

  • labels (List) - The labels for new_edge_pairs.
    • If task == “existence”: 0 (the directed edge exists in the graph), 1 (the edge doesn’t exist). The undirected edges in the directed input graph are removed to avoid ambiguity.

    • If task == “direction”: 0 (the directed edge exists in the graph), 1 (the edge of the reversed direction exists). The undirected edges in the directed input graph are removed to avoid ambiguity.

    • If task == “three_class_digraph”: 0 (the directed edge exists in the graph), 1 (the edge of the reversed direction exists), 2 (the edge exists in neither direction). The undirected edges in the directed input graph are removed to avoid ambiguity.

    • If task == “four_class_signed_digraph”: 0 (the positive directed edge exists in the graph), 1 (the negative directed edge exists in the graph), 2 (the positive edge of the reversed direction exists), 3 (the negative edge of the reversed direction exists). The undirected edges in the directed input graph are removed to avoid ambiguity.

    • If task == “five_class_signed_digraph”: 0 (the positive directed edge exists in the graph), 1 (the negative directed edge exists in the graph), 2 (the positive edge of the reversed direction exists), 3 (the negative edge of the reversed direction exists), 4 (the edge exists in neither direction). The undirected edges in the directed input graph are removed to avoid ambiguity.

    • If task == “sign”: 0 (negative edge), 1 (positive edge).

  • label_weight (List) - The weight list of the query edges. The weight is zero if the edge exists in neither direction.

  • undirected (List) - The undirected edges list within the input graph.
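
The directed label scheme above can be illustrated with a small pure-Python sketch. This is not the library's implementation: `sketch_directed_labels` is a hypothetical helper that takes a plain edge list instead of a sparse matrix, and two of its choices are assumptions. For “existence”, a pair whose only edge runs in the reversed direction is labelled 1; for “direction”, pairs with no edge in either direction are skipped.

```python
from typing import Iterable, List, Tuple

Edge = Tuple[int, int]


def sketch_directed_labels(
    edges: Iterable[Edge], queries: Iterable[Edge], task: str
) -> Tuple[List[Edge], List[int]]:
    """Hypothetical helper: label query pairs with the scheme described above."""
    edge_set = set(edges)
    kept: List[Edge] = []
    labels: List[int] = []
    for u, v in queries:
        forward = (u, v) in edge_set
        backward = (v, u) in edge_set
        if forward and backward:
            # Edge present in both directions ("undirected"): dropped to avoid ambiguity.
            continue
        if task == "existence":
            # Assumption: a reversed-only pair counts as "does not exist".
            label = 0 if forward else 1
        elif task == "direction":
            if not (forward or backward):
                continue  # Assumption: pairs without any edge are not queried.
            label = 0 if forward else 1
        elif task == "three_class_digraph":
            label = 0 if forward else (1 if backward else 2)
        else:
            raise ValueError(f"unsupported task in this sketch: {task}")
        kept.append((u, v))
        labels.append(label)
    return kept, labels


edges = [(0, 1), (1, 2), (2, 1), (3, 0)]
queries = [(0, 1), (0, 3), (1, 2), (2, 3)]
print(sketch_directed_labels(edges, queries, "three_class_digraph"))
# ([(0, 1), (0, 3), (2, 3)], [0, 1, 2])
```

The pair (1, 2) is dropped because the edge exists in both directions, which matches the note that undirected edges are removed.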

link_class_split(data: torch_geometric.data.Data, size: int = None, splits: int = 2, prob_test: float = 0.15, prob_val: float = 0.05, task: str = 'direction', seed: int = 0, maintain_connect: bool = True, ratio: float = 1.0, device: str = 'cpu') → dict

Get train/val/test dataset for the link prediction task.

Arg types:

  • data (torch_geometric.data.Data or DirectedData object) - The input dataset.

  • prob_val (float, optional) - The proportion of edges selected for validation (Default: 0.05).

  • prob_test (float, optional) - The proportion of edges selected for testing (Default: 0.15).

  • splits (int, optional) - The number of splits to generate (Default: 2).

  • size (int, optional) - The size of the input graph. If None, the graph size is the maximum node index plus 1 (Default: None).

  • task (str, optional) - The evaluation task: three_class_digraph (three-class link prediction); direction (direction prediction); existence (existence prediction); sign (sign prediction); four_class_signed_digraph (directed sign prediction); five_class_signed_digraph (directed sign and existence prediction) (Default: ‘direction’)

  • seed (int, optional) - The random seed for positive edge selection (Default: 0). Negative edges are selected by PyTorch Geometric's negative_sampling.

  • maintain_connect (bool, optional) - Whether to maintain connectivity when removing edges for validation and testing. Connectivity is maintained by first obtaining the edges of the minimum spanning tree/forest. These edges are never removed for validation or testing (Default: True).

  • ratio (float, optional) - The maximum ratio of edges used for dataset generation. (Default: 1.0)

  • device (str, optional) - The device to hold the return value (Default: ‘cpu’).
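
The idea behind maintain_connect can be sketched in pure Python as follows. This is a minimal illustration under stated assumptions, not the library's code: `sketch_connected_split` is a hypothetical helper, it builds an arbitrary spanning forest with union-find (treating edges as undirected and unweighted) rather than a minimum spanning tree, and it takes prob_val and prob_test as fractions of the total edge count.

```python
import random
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def sketch_connected_split(
    edges: List[Edge],
    num_nodes: int,
    prob_val: float = 0.05,
    prob_test: float = 0.15,
    seed: int = 0,
) -> Dict[str, List[Edge]]:
    """Hypothetical helper: hold out val/test edges without touching a spanning forest."""
    parent = list(range(num_nodes))

    def find(x: int) -> int:
        # Union-find root lookup with path halving.
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    forest: List[Edge] = []     # edges that keep the graph connected
    removable: List[Edge] = []  # edges that may be held out
    for u, v in edges:
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            forest.append((u, v))
        else:
            removable.append((u, v))

    rng = random.Random(seed)  # seeded so the split is reproducible
    rng.shuffle(removable)
    n_val = min(int(prob_val * len(edges)), len(removable))
    n_test = min(int(prob_test * len(edges)), len(removable) - n_val)
    return {
        "train": forest + removable[n_val + n_test:],
        "val": removable[:n_val],
        "test": removable[n_val:n_val + n_test],
    }


# Complete graph on 6 nodes: 15 edges, 5 of which form the spanning forest.
edges = [(i, j) for i in range(6) for j in range(i + 1, 6)]
split = sketch_connected_split(edges, num_nodes=6, prob_val=0.2, prob_test=0.2)
print(len(split["train"]), len(split["val"]), len(split["test"]))  # 9 3 3
```

Because the forest edges always stay in the training portion, every node that was reachable in the input graph remains reachable in the observed graph.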

Return types:
  • datasets - A dict that includes the training/validation/testing splits of edges and labels. For split index i:
    • datasets[i]['graph'] (torch.LongTensor): the observed edge list after removing edges for validation and testing.

    • datasets[i]['train'/'val'/'testing']['edges'] (List): the edge list for training/validation/testing.

    • datasets[i]['train'/'val'/'testing']['label'] (List): the labels of edges:
      • If task == “existence”: 0 (the directed edge exists in the graph), 1 (the edge doesn’t exist). The undirected edges in the directed input graph are removed to avoid ambiguity.

      • If task == “direction”: 0 (the directed edge exists in the graph), 1 (the edge of the reversed direction exists). The undirected edges in the directed input graph are removed to avoid ambiguity.

      • If task == “three_class_digraph”: 0 (the directed edge exists in the graph), 1 (the edge of the reversed direction exists), 2 (the edge exists in neither direction). The undirected edges in the directed input graph are removed to avoid ambiguity.

      • If task == “four_class_signed_digraph”: 0 (the positive directed edge exists in the graph), 1 (the negative directed edge exists in the graph), 2 (the positive edge of the reversed direction exists), 3 (the negative edge of the reversed direction exists). The undirected edges in the directed input graph are removed to avoid ambiguity.

      • If task == “five_class_signed_digraph”: 0 (the positive directed edge exists in the graph), 1 (the negative directed edge exists in the graph), 2 (the positive edge of the reversed direction exists), 3 (the negative edge of the reversed direction exists), 4 (the edge exists in neither direction). The undirected edges in the directed input graph are removed to avoid ambiguity.

      • If task == “sign”: 0 (negative edge), 1 (positive edge). This is the link sign prediction task for signed networks.
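
For the signed tasks, the labels can be illustrated with a similar pure-Python sketch. `sketch_signed_labels` is a hypothetical helper, not part of the library. It assumes that edges are given as a dict from (source, target) to a nonzero weight whose sign is the edge sign, that label 3 means a negative edge in the reversed direction, and that the “sign” task labels only pairs whose directed edge exists.

```python
from typing import Dict, Iterable, List, Tuple

Edge = Tuple[int, int]


def sketch_signed_labels(
    signed_edges: Dict[Edge, float], queries: Iterable[Edge], task: str
) -> Tuple[List[Edge], List[int]]:
    """Hypothetical helper: label query pairs for the signed tasks described above."""
    kept: List[Edge] = []
    labels: List[int] = []
    for u, v in queries:
        w_fwd = signed_edges.get((u, v))
        w_bwd = signed_edges.get((v, u))
        if task == "sign":
            if w_fwd is None:
                continue  # Assumption: only existing edges are labelled.
            label = 1 if w_fwd > 0 else 0
        elif task in ("four_class_signed_digraph", "five_class_signed_digraph"):
            if w_fwd is not None and w_bwd is not None:
                continue  # Edge in both directions: dropped to avoid ambiguity.
            if w_fwd is not None:
                label = 0 if w_fwd > 0 else 1
            elif w_bwd is not None:
                label = 2 if w_bwd > 0 else 3
            elif task == "five_class_signed_digraph":
                label = 4  # No edge in either direction.
            else:
                continue  # The four-class task has no "no edge" label.
        else:
            raise ValueError(f"unsupported task in this sketch: {task}")
        kept.append((u, v))
        labels.append(label)
    return kept, labels


signed_edges = {(0, 1): 1.0, (1, 2): -1.0, (2, 3): 2.0,
                (3, 2): -1.0, (4, 0): -3.0, (5, 4): 1.0}
queries = [(0, 1), (1, 2), (2, 3), (0, 4), (4, 5), (1, 5)]
print(sketch_signed_labels(signed_edges, queries, "five_class_signed_digraph"))
# ([(0, 1), (1, 2), (0, 4), (4, 5), (1, 5)], [0, 1, 3, 2, 4])
```

Switching the task to “four_class_signed_digraph” yields the same result without the final pair (1, 5), since that task has no label for absent edges.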


© Copyright 2026, Yixuan He.
