Source code for torch_geometric_signed_directed.data.signed.SDGNN_real_data

from typing import Optional, Callable
import os
import json

import torch
from torch_geometric.data import (InMemoryDataset, download_url, Data)


# Maps each supported dataset name to the URL of its raw CSV edge list.
# Every file lives in the same GitHub directory, so only the file name
# varies between entries (note that "wiki" is stored as ``wikirfa.csv``).
dataset_name_url_dic = {
    dataset: "https://github.com/SherylHYX/pytorch_geometric_signed_directed/raw/main/datasets/" + csv_file
    for dataset, csv_file in (
        ("bitcoin_alpha", "bitcoin_alpha.csv"),
        ("bitcoin_otc", "bitcoin_otc.csv"),
        ("wiki", "wikirfa.csv"),
        ("epinions", "epinions.csv"),
        ("slashdot", "slashdot.csv"),
    )
}


class SDGNN_real_data(InMemoryDataset):
    r"""Signed Directed Graph from the
    `"SDGNN: Learning Node Representation for Signed Directed Networks"
    <https://arxiv.org/abs/2101.02390>`_ paper, consisting of five different
    datasets: Bitcoin-Alpha, Bitcoin-OTC, Wikirfa, Slashdot and Epinions from
    `snap.stanford.edu <http://snap.stanford.edu/data/#signnets>`_.

    Args:
        name (str): Name of the dataset, choices are: 'bitcoin_alpha',
            'bitcoin_otc', 'wiki', 'epinions', 'slashdot'. The name is
            matched case-insensitively.
        root (str): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)

    Raises:
        KeyError: If :obj:`name` is not one of the supported datasets.
    """

    def __init__(self, name: str, root: str,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        self.name = name.lower()
        # Look the URL up with the normalised name; indexing with the raw
        # ``name`` would reject e.g. "Bitcoin_Alpha" even though the rest of
        # the class works with the lower-cased form.
        try:
            self.url = dataset_name_url_dic[self.name]
        except KeyError:
            raise KeyError(
                f"Unknown dataset {name!r}; choices are: "
                f"{', '.join(dataset_name_url_dic)}") from None
        # ``raw_dir`` / ``processed_dir`` below read ``self.root``, so it is
        # set before the parent constructor runs.
        self.root = root
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        """Directory holding the downloaded CSV: ``<root>/<name>/raw``."""
        return os.path.join(self.root, self.name, 'raw')

    @property
    def processed_dir(self) -> str:
        """Directory holding processed files: ``<root>/<name>/processed``."""
        return os.path.join(self.root, self.name, 'processed')

    @property
    def raw_file_names(self) -> str:
        """File name of the raw CSV, i.e. the last path segment of the URL."""
        _, _, filename = self.url.rpartition('/')
        return filename

    @property
    def processed_file_names(self) -> str:
        """File name of the collated dataset inside ``processed_dir``."""
        return 'data.pt'

    def download(self):
        """Download the raw CSV for this dataset into ``raw_dir``."""
        download_url(self.url, self.raw_dir)

    def process(self):
        """Convert the raw CSV edge list into a saved ``Data`` object.

        Each non-empty line must have the form ``source,target,weight``.
        Node identifiers are relabelled to consecutive integers in order of
        first appearance, and that mapping is written to
        ``node_id_map.json`` inside ``processed_dir``.

        Raises:
            ValueError: If a non-empty line does not have exactly three
                comma-separated fields, or its weight is not a number.
        """
        node_map = {}
        edges = []
        edge_weight = []

        with open(self.raw_paths[0], 'r') as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                fields = line.split(',')
                if len(fields) != 3:
                    # Explicit check instead of ``assert`` so the validation
                    # still runs under ``python -O``.
                    raise ValueError(
                        f"{self.raw_paths[0]}:{line_no}: expected 3 "
                        f"comma-separated fields, got {len(fields)}")
                src, dst, weight = fields
                # ``len(node_map)`` is evaluated before the insertion, so a
                # newly seen node receives the next free consecutive id.
                src_id = node_map.setdefault(src, len(node_map))
                dst_id = node_map.setdefault(dst, len(node_map))
                edges.append((src_id, dst_id))
                edge_weight.append(float(weight))

        # ``reshape(-1, 2)`` keeps the result two-dimensional even when the
        # file has no edges, so ``edge_index`` always has shape (2, E).
        edge_index = torch.tensor(edges, dtype=torch.long)
        edge_index = edge_index.reshape(-1, 2).t().contiguous()
        edge_weight = torch.tensor(edge_weight, dtype=torch.float)

        map_file = os.path.join(self.processed_dir, 'node_id_map.json')
        with open(map_file, 'w') as f:
            json.dump(node_map, f)

        data = Data(edge_index=edge_index, edge_weight=edge_weight)
        if self.pre_transform is not None:
            data = self.pre_transform(data)

        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])

    @property
    def num_nodes(self) -> int:
        """Number of nodes, derived as the largest node id plus one.

        Returns 0 when the graph has no edges.
        """
        edge_index = self.data.edge_index
        if edge_index.numel() == 0:
            return 0
        return edge_index.max().item() + 1