# Source code for torch_geometric_signed_directed.utils.directed.meta_graph_generation

import math

import numpy as np


def meta_graph_generation(F_style: str = 'cyclic', K: int = 4, eta: float = 0.05,
                          ambient: bool = False, fill_val: float = 0.5) -> np.ndarray:
    """The meta-graph generation function from the
    `DIGRAC: Digraph Clustering Based on Flow Imbalance
    <https://proceedings.mlr.press/v198/he22b.html>`_ paper.

    Every diagonal entry is 0.5. For each pair of clusters ``(i, j)`` linked
    by the chosen structure, the dominant direction gets ``1 - eta`` and the
    reverse gets ``eta``, so ``F[i, j] + F[j, i] == 1``. Pairs not linked by
    the structure, and (when ``ambient`` is True) the whole last row and
    column, are set to ``fill_val``. The ``'complete'`` style draws each
    pair's direction from NumPy's global random state.

    Arg types:
        * **F_style** (str) - Style of the meta-graph: 'cyclic', 'path', 'complete', 'star' or 'multipartite'.
        * **K** (int) - Number of clusters.
        * **eta** (float) - Noise parameter, 0 <= eta <= 0.5.
        * **ambient** (bool) - Whether there are ambient nodes (treated as the last cluster).
        * **fill_val** (float) - Value to fill in the ambient locations and unlinked cluster pairs.

    Return types:
        * **F** (NumPy array) - The resulting meta-graph adjacency matrix.

    Raises:
        ValueError: If ``F_style`` is unknown, or ``K`` is too small for the
            'star' / 'multipartite' styles.
    """
    # Unlinked entries are marked by staying at 0 and are filled at the end.
    # With eta == 0 the "eta" entries would also be 0, so use sentinels
    # -1 (for eta) and 2 (for 1 - eta) and map them back to 0 / 1 below.
    # NOTE: this relies on eta in [0, 0.5] and fill_val not in {-1, 2}.
    if eta == 0:
        eta = -1
    F = np.eye(K) * 0.5

    if F_style == 'path':
        # cluster i -> cluster i + 1
        for i in range(K - 1):
            j = i + 1
            F[i, j] = 1 - eta
            F[j, i] = 1 - F[i, j]
    elif F_style == 'cyclic':
        if K > 2:
            # with ambient nodes the cycle only runs over the first K - 1 clusters
            n_cycle = K - 1 if ambient else K
            for i in range(n_cycle):
                j = (i + 1) % n_cycle
                F[i, j] = 1 - eta
                F[j, i] = 1 - F[i, j]
        elif ambient:
            F = np.array([[0.5, 0.5], [0.5, 0.5]])
        else:
            F = np.array([[0.5, 1 - eta], [eta, 0.5]])
    elif F_style == 'complete':
        if K > 2:
            for i in range(K - 1):
                for j in range(i + 1, K):
                    # random direction; draw a scalar (same RNG stream as
                    # size=1) so no size-1 array is assigned into one element
                    direction = np.random.randint(2)
                    F[i, j] = direction * (1 - eta) + (1 - direction) * eta
                    F[j, i] = 1 - F[i, j]
        else:
            F = np.array([[0.5, 1 - eta], [eta, 0.5]])
    elif F_style == 'star':
        if K < 3:
            raise ValueError("Sorry, star shape requires K at least 3!")
        if ambient and K < 4:
            raise ValueError(
                "Sorry, star shape with ambient nodes requires K at least 4!")
        center_ind = math.floor((K - 1) / 2)
        F[center_ind, ::2] = eta  # only even indices
        F[center_ind, 1::2] = 1 - eta  # only odd indices
        F[::2, center_ind] = 1 - eta
        F[1::2, center_ind] = eta
        # the strided slices above also overwrite the centre's own diagonal
        # entry; restore it so every diagonal entry is 0.5
        F[center_ind, center_ind] = 0.5
    elif F_style == 'multipartite':
        if K < 3:
            raise ValueError("Sorry, multipartite shape requires K at least 3!")
        if ambient and K < 4:
            raise ValueError(
                "Sorry, multipartite shape with ambient nodes requires K at least 4!")
        # the ambient (last) cluster is excluded when sizing the groups
        n_part = K - 1 if ambient else K
        G1_ind = math.ceil(n_part / 9)
        G2_ind = math.ceil(n_part * 3 / 9) + G1_ind
        F[:G1_ind, G1_ind:G2_ind] = eta
        F[G1_ind:G2_ind, G2_ind:] = eta
        F[G2_ind:, G1_ind:G2_ind] = 1 - eta
        F[G1_ind:G2_ind, :G1_ind] = 1 - eta
    else:
        raise ValueError("Sorry, please give correct F style string!")

    if ambient:
        # the last cluster holds the ambient nodes: clear its row and column
        F[-1, :] = 0
        F[:, -1] = 0
    # entries still at 0 are unlinked pairs / ambient locations
    F[F == 0] = fill_val
    # undo the eta == 0 sentinels
    F[F == -1] = 0
    F[F == 2] = 1
    return F