from typing import Optional, Tuple
import torch
from torch.nn import Parameter
from torch_geometric.nn.inits import zeros, glorot
from torch_geometric.typing import OptTensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops
from ...utils.general.get_magnetic_signed_Laplacian import get_magnetic_signed_Laplacian
class MSConv(MessagePassing):
r"""Magnetic Signed Laplacian Convolution Layer from the
    `MSGNN: A Spectral Graph Neural Network Based on a Novel Magnetic Signed Laplacian <https://proceedings.mlr.press/v198/he22c.html>`_ paper.

    Args:
in_channels (int): Size of each input sample.
out_channels (int): Size of each output sample.
K (int): Chebyshev filter size :math:`K`.
        q (float): Initial value of the phase parameter, :math:`0 \leq q \leq 0.25`.
        trainable_q (bool): Whether :math:`q` is treated as a trainable parameter.
normalization (str, optional): The normalization scheme for the magnetic
Laplacian (default: :obj:`sym`):
1. :obj:`None`: No normalization
:math:`\mathbf{L} = \bar{\mathbf{D}} - \mathbf{A} \odot \exp(i \Theta^{(q)})`
2. :obj:`"sym"`: Symmetric normalization
:math:`\mathbf{L} = \mathbf{I} - \bar{\mathbf{D}}^{-1/2} \mathbf{A}
\bar{\mathbf{D}}^{-1/2} \odot \exp(i \Theta^{(q)})`
            :math:`\odot` denotes element-wise multiplication.
cached (bool, optional): If set to :obj:`True`, the layer will cache
the __norm__ matrix on first execution, and will use the
cached version for further executions.
This parameter should only be set to :obj:`True` in transductive
learning scenarios. (default: :obj:`False`)
bias (bool, optional): If set to :obj:`False`, the layer will not learn
an additive bias. (default: :obj:`True`)
absolute_degree (bool, optional): Whether to calculate the degree matrix with respect to absolute entries of the adjacency matrix. (default: :obj:`True`)
**kwargs (optional): Additional arguments of
:class:`torch_geometric.nn.conv.MessagePassing`.
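
    Example:
        A minimal usage sketch on a small signed toy graph (the sizes and
        values below are illustrative only):

        >>> import torch
        >>> conv = MSConv(in_channels=4, out_channels=8, K=2, q=0.25,
        ...               trainable_q=False)
        >>> x_real = torch.randn(10, 4)
        >>> x_imag = torch.zeros(10, 4)
        >>> edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]])
        >>> edge_weight = torch.tensor([1.0, -1.0, 1.0, -1.0])
        >>> out_real, out_imag = conv(x_real, x_imag, edge_index, edge_weight)
        >>> out_real.shape
        torch.Size([10, 8])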
"""
    def __init__(self, in_channels: int, out_channels: int, K: int, q: float, trainable_q: bool,
                 normalization: str = 'sym', bias: bool = True, cached: bool = False,
                 absolute_degree: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        # 'flow' must be set before MessagePassing consumes **kwargs;
        # setting it after the super() call would have no effect.
        kwargs.setdefault('flow', 'target_to_source')
        super(MSConv, self).__init__(**kwargs)
        assert K > 0
        assert normalization in [None, 'sym'], 'Invalid normalization'
self.in_channels = in_channels
self.out_channels = out_channels
self.normalization = normalization
self.cached = cached
self.trainable_q = trainable_q
self.absolute_degree = absolute_degree
if trainable_q:
self.q = Parameter(torch.Tensor(1).fill_(q))
else:
self.q = q
self.weight = Parameter(torch.Tensor(K, in_channels, out_channels))
if bias:
self.bias = Parameter(torch.Tensor(out_channels))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
glorot(self.weight)
zeros(self.bias)
self.cached_result = None
self.cached_num_edges = None
self.cached_q = None
def __norm__(
self,
edge_index,
num_nodes: Optional[int],
edge_weight: OptTensor,
q: float,
normalization: Optional[str],
lambda_max,
        dtype: Optional[torch.dtype] = None
):
"""
Get the magnetic signed Laplacian.
Arg types:
* edge_index (PyTorch Long Tensor) - Edge indices.
        * num_nodes (int, optional) - Number of nodes.
        * edge_weight (PyTorch Float Tensor, optional) - Edge weights corresponding to edge indices.
        * lambda_max (float or PyTorch Tensor) - Largest eigenvalue of the Laplacian, used to rescale it for the Chebyshev basis.
        Return types:
        * edge_index_real, edge_index_imag, edge_weight_real, edge_weight_imag - Edge indices (Long) and weights (Float) of the real and imaginary parts of the magnetic signed Laplacian.
"""
edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
edge_index, edge_weight_real, edge_weight_imag = get_magnetic_signed_Laplacian(
edge_index, edge_weight, normalization, dtype, num_nodes, q, absolute_degree=self.absolute_degree
)
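        # Rescale to the Chebyshev domain: L_tilde = 2L / lambda_max - I. The -I
        # shift is applied to the real part below via self-loops with
        # fill_value=-1.0; the imaginary part has no diagonal entries.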
edge_weight_real = (2.0 * edge_weight_real) / lambda_max
edge_weight_real.masked_fill_(edge_weight_real == float("inf"), 0)
edge_index_imag = edge_index.clone()
edge_index_real, edge_weight_real = add_self_loops(
edge_index, edge_weight_real, fill_value=-1.0, num_nodes=num_nodes
)
assert edge_weight_real is not None
edge_weight_imag = (2.0 * edge_weight_imag) / lambda_max
edge_weight_imag.masked_fill_(edge_weight_imag == float("inf"), 0)
assert edge_weight_imag is not None
return edge_index_real, edge_index_imag, edge_weight_real, edge_weight_imag
    def forward(
self,
x_real: torch.FloatTensor,
x_imag: torch.FloatTensor,
edge_index: torch.LongTensor,
edge_weight: OptTensor = None,
lambda_max: OptTensor = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
"""
        Making a forward pass of the Magnetic Signed Laplacian convolution layer.
Arg types:
* x_real, x_imag (PyTorch Float Tensor) - Node features.
* edge_index (PyTorch Long Tensor) - Edge indices.
* edge_weight (PyTorch Float Tensor, optional) - Edge weights corresponding to edge indices.
        * lambda_max (optional) - Largest eigenvalue of the Laplacian. Defaults to 2.0 under symmetric normalization; otherwise it is computed on the fly, which is not supported when q is trainable.
Return types:
* out_real, out_imag (PyTorch Float Tensor) - Hidden state tensor for all nodes, with shape (N_nodes, F_out).
"""
        if self.trainable_q:
            # Clamp in place: re-wrapping q in a fresh Parameter on every forward
            # pass would detach it from any optimizer that references it.
            with torch.no_grad():
                self.q.clamp_(0, 0.25)
if self.cached and self.cached_result is not None:
if edge_index.size(1) != self.cached_num_edges:
raise RuntimeError(
'Cached {} number of edges, but found {}. Please '
'disable the caching behavior of this layer by removing '
'the `cached=True` argument in its constructor.'.format(
self.cached_num_edges, edge_index.size(1)))
if self.q != self.cached_q:
raise RuntimeError(
'Cached q is {}, but found {} in input. Please '
'disable the caching behavior of this layer by removing '
'the `cached=True` argument in its constructor.'.format(
self.cached_q, self.q))
if not self.cached or self.cached_result is None:
self.cached_num_edges = edge_index.size(1)
if self.trainable_q:
self.cached_q = self.q.detach().item()
else:
self.cached_q = self.q
if self.normalization != 'sym' and lambda_max is None:
if self.trainable_q:
raise RuntimeError('Cannot train q while not calculating maximum eigenvalue of Laplacian!')
_, _, _, lambda_max = get_magnetic_signed_Laplacian(
edge_index, edge_weight, None, q=self.q, return_lambda_max=True, absolute_degree=self.absolute_degree
)
if lambda_max is None:
lambda_max = torch.tensor(2.0, dtype=x_real.dtype, device=x_real.device)
if not isinstance(lambda_max, torch.Tensor):
lambda_max = torch.tensor(lambda_max, dtype=x_real.dtype,
device=x_real.device)
assert lambda_max is not None
edge_index_real, edge_index_imag, norm_real, norm_imag = self.__norm__(edge_index, x_real.size(self.node_dim),
edge_weight, self.q, self.normalization,
lambda_max, dtype=x_real.dtype)
self.cached_result = edge_index_real, edge_index_imag, norm_real, norm_imag
edge_index_real, edge_index_imag, norm_real, norm_imag = self.cached_result
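        # Chebyshev recursion on the rescaled Laplacian: T_0(L)x = x, T_1(L)x = Lx,
        # T_k(L)x = 2L T_{k-1}(L)x - T_{k-2}(L)x. The real and imaginary parts of
        # the Laplacian act on each input part separately, so four streams are
        # tracked below; the suffix reads <Laplacian part>_<feature part>.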
Tx_0_real_real = x_real
Tx_0_imag_imag = x_imag
Tx_0_imag_real = x_real
Tx_0_real_imag = x_imag
out_real_real = torch.matmul(Tx_0_real_real, self.weight[0])
out_imag_imag = torch.matmul(Tx_0_imag_imag, self.weight[0])
out_imag_real = torch.matmul(Tx_0_imag_real, self.weight[0])
out_real_imag = torch.matmul(Tx_0_real_imag, self.weight[0])
# propagate_type: (x: Tensor, norm: Tensor)
if self.weight.size(0) > 1:
Tx_1_real_real = self.propagate(edge_index_real, x=x_real, norm=norm_real, size=None)
out_real_real = out_real_real + torch.matmul(Tx_1_real_real, self.weight[1])
Tx_1_imag_imag = self.propagate(edge_index_imag, x=x_imag, norm=norm_imag, size=None)
out_imag_imag = out_imag_imag + torch.matmul(Tx_1_imag_imag, self.weight[1])
            Tx_1_imag_real = self.propagate(edge_index_imag, x=x_real, norm=norm_imag, size=None)
            out_imag_real = out_imag_real + torch.matmul(Tx_1_imag_real, self.weight[1])
            Tx_1_real_imag = self.propagate(edge_index_real, x=x_imag, norm=norm_real, size=None)
            out_real_imag = out_real_imag + torch.matmul(Tx_1_real_imag, self.weight[1])
for k in range(2, self.weight.size(0)):
Tx_2_real_real = self.propagate(edge_index_real, x=Tx_1_real_real, norm=norm_real, size=None)
Tx_2_real_real = 2. * Tx_2_real_real - Tx_0_real_real
out_real_real = out_real_real + torch.matmul(Tx_2_real_real, self.weight[k])
Tx_0_real_real, Tx_1_real_real = Tx_1_real_real, Tx_2_real_real
Tx_2_imag_imag = self.propagate(edge_index_imag, x=Tx_1_imag_imag, norm=norm_imag, size=None)
Tx_2_imag_imag = 2. * Tx_2_imag_imag - Tx_0_imag_imag
out_imag_imag = out_imag_imag + torch.matmul(Tx_2_imag_imag, self.weight[k])
Tx_0_imag_imag, Tx_1_imag_imag = Tx_1_imag_imag, Tx_2_imag_imag
                Tx_2_imag_real = self.propagate(edge_index_imag, x=Tx_1_imag_real, norm=norm_imag, size=None)
Tx_2_imag_real = 2. * Tx_2_imag_real - Tx_0_imag_real
out_imag_real = out_imag_real + torch.matmul(Tx_2_imag_real, self.weight[k])
Tx_0_imag_real, Tx_1_imag_real = Tx_1_imag_real, Tx_2_imag_real
                Tx_2_real_imag = self.propagate(edge_index_real, x=Tx_1_real_imag, norm=norm_real, size=None)
Tx_2_real_imag = 2. * Tx_2_real_imag - Tx_0_real_imag
out_real_imag = out_real_imag + torch.matmul(Tx_2_real_imag, self.weight[k])
Tx_0_real_imag, Tx_1_real_imag = Tx_1_real_imag, Tx_2_real_imag
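        # Combine the four streams as a complex product (L_re + i*L_im)(x_re + i*x_im):
        # real part = L_re*x_re - L_im*x_im, imaginary part = L_im*x_re + L_re*x_im.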
out_real = out_real_real - out_imag_imag
out_imag = out_imag_real + out_real_imag
if self.bias is not None:
out_real += self.bias
out_imag += self.bias
return out_real, out_imag
    def message(self, x_j, norm):
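        # Weight each neighbor message by its Laplacian entry (real or imaginary).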
return norm.view(-1, 1) * x_j
def __repr__(self):
return '{}({}, {}, K={}, normalization={})'.format(
self.__class__.__name__, self.in_channels, self.out_channels,
self.weight.size(0), self.normalization)