Source code for torch_geometric_signed_directed.utils.general.extract_network

from typing import Tuple, Union

import numpy as np
import scipy.sparse as sp
import networkx as nx
from torch import LongTensor


def extract_network(A: sp.spmatrix, labels: Union[np.array, LongTensor, None] = None,
                    lowest_degree: int = 2, max_iter: int = 10) -> Tuple[sp.spmatrix, np.array]:
    """Find the largest weakly connected component, then iteratively prune low-degree nodes.

    Following the `DIGRAC: Digraph Clustering Based on Flow Imbalance
    <https://arxiv.org/pdf/2106.05194.pdf>`_ paper, the graph is first restricted to its
    largest weakly connected component. Then, for at most ``max_iter`` iterations, every
    node whose total (in + out) degree is below ``lowest_degree`` is removed.

    - Pruning stops early once no remaining node falls below the threshold.
    - If a pruning step would remove every node, ``lowest_degree`` is lowered by one and
      pruning restarts from the full largest component. The iteration budget is shared,
      not reset.
    - Pruning can itself disconnect the component, so the output is not guaranteed to be
      connected.

    Arg types:
        * **A** (scipy sparse matrix) - Adjacency matrix (must support row/column fancy indexing).
        * **labels** (numpy array or torch.LongTensor, optional) - Node labels, default None.
        * **lowest_degree** (int, optional) - The lowest degree for the output network, default 2.
        * **max_iter** (int, optional) - The maximum number of pruning iterations, default 10.

    Return types:
        * **A** (scipy sparse matrix) - Adjacency matrix of the pruned largest component.
        * **labels** (numpy array) - Node labels of the pruned largest component. None when
          ``labels`` is None. Labels are indexed from the input, so a torch.LongTensor input
          yields a torch.LongTensor.
    """
    # networkx 3.0 removed ``from_scipy_sparse_matrix``. ``from_scipy_sparse_array``
    # (networkx >= 2.7) replaces it, so prefer it when available.
    from_sparse = getattr(nx, "from_scipy_sparse_array", None) or nx.from_scipy_sparse_matrix

    G = from_sparse(A, create_using=nx.DiGraph)
    # ``key=len`` is required: plain ``max`` over sets uses subset ordering and does not
    # reliably return the largest component. Sorting keeps the original node order.
    largest_cc = sorted(max(nx.weakly_connected_components(G), key=len))
    A_new = A[largest_cc][:, largest_cc]
    labels_new = None
    if labels is not None:
        labels_new = labels[largest_cc]

    # Node ids of G0 are row/column positions in A_new, not in the original A.
    G0 = from_sparse(A_new, create_using=nx.DiGraph)
    iter_num = 0
    while iter_num < max_iter:
        iter_num += 1
        degrees = dict(G0.degree())
        remove = [node for node, degree in degrees.items() if degree < lowest_degree]
        if len(remove) == len(degrees):
            # Nothing would survive: relax the threshold and restart from the full component.
            lowest_degree -= 1
            print('Nothing to keep, reducing lowest_degree by one to be {}!'.format(
                lowest_degree))
            G0 = from_sparse(A_new, create_using=nx.DiGraph)
            continue
        if not remove:
            break  # every remaining node already meets the threshold
        G0.remove_nodes_from(remove)

    # The surviving nodes are exactly G0's nodes. Index A_new / labels_new (not A / labels)
    # because these ids are positions inside the largest component. With max_iter <= 0
    # this keeps the whole largest component.
    keep = np.array(sorted(G0.nodes()), dtype=int)
    A_new = A_new[keep][:, keep]
    if labels_new is not None:
        labels_new = labels_new[keep]
    return A_new, labels_new