from typing import NamedTuple, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.utils import (
    dense_to_sparse,
    one_hot,
    to_dense_adj,
    to_scipy_sparse_matrix,
)


class UnpoolInfo(NamedTuple):
    edge_index: Tensor
    cluster: Tensor
    batch: Tensor


class ClusterPooling(torch.nn.Module):
    r"""The cluster pooling operator from the `"Edge-Based Graph Component
    Pooling" <paper url>`_ paper.

    :class:`ClusterPooling` computes a score for each edge.
    Based on the selected edges, graph clusters are calculated and compressed
    to one node using the injective :obj:`"sum"` aggregation function.
    Edges are remapped based on the nodes created by each cluster and the
    original edges.

    Args:
        in_channels (int): Size of each input sample.
        edge_score_method (str, optional): The function to apply
            to compute the edge score from raw edge scores (:obj:`"tanh"`,
            :obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`)
        dropout (float, optional): The probability with
            which to drop edge scores during training. (default: :obj:`0.0`)
        threshold (float, optional): The threshold of edge scores. If set to
            :obj:`None`, will be automatically inferred depending on
            :obj:`edge_score_method`. (default: :obj:`None`)
    """
    def __init__(
        self,
        in_channels: int,
        edge_score_method: str = 'tanh',
        dropout: float = 0.0,
        threshold: Optional[float] = None,
    ):
        super().__init__()
        assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax']

        if threshold is None:
            threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0

        self.in_channels = in_channels
        self.edge_score_method = edge_score_method
        self.dropout = dropout
        self.threshold = threshold

        self.lin = torch.nn.Linear(2 * in_channels, 1)

    def reset_parameters(self):
        r"""Resets all learnable parameters of the module."""
        self.lin.reset_parameters()

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
        r"""Forward pass.

        Args:
            x (torch.Tensor): The node features.
            edge_index (torch.Tensor): The edge indices.
            batch (torch.Tensor): Batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
                each node to a specific example.

        Return types:
            * **x** *(torch.Tensor)* - The pooled node features.
            * **edge_index** *(torch.Tensor)* - The coarsened edge indices.
            * **batch** *(torch.Tensor)* - The coarsened batch vector.
            * **unpool_info** *(UnpoolInfo)* - Information that can be
              consumed for unpooling.
        """
        # Remove self-loops, since they cannot be contracted:
        mask = edge_index[0] != edge_index[1]
        edge_index = edge_index[:, mask]

        # Compute a raw score for each remaining edge from the features of
        # its two endpoints:
        edge_attr = torch.cat(
            [x[edge_index[0]], x[edge_index[1]]],
            dim=-1,
        )
        edge_score = self.lin(edge_attr).view(-1)
        edge_score = F.dropout(edge_score, p=self.dropout,
                               training=self.training)

        if self.edge_score_method == 'tanh':
            edge_score = edge_score.tanh()
        elif self.edge_score_method == 'sigmoid':
            edge_score = edge_score.sigmoid()
        else:
            assert self.edge_score_method == 'log_softmax'
            edge_score = F.log_softmax(edge_score, dim=0)

        return self._merge_edges(x, edge_index, batch, edge_score)

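    # The helper below contracts all edges whose score exceeds the threshold,
    # merges each resulting connected component into a single node, and
    # remaps features, edges, and the batch vector onto the coarsened graph.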
    def _merge_edges(
        self,
        x: Tensor,
        edge_index: Tensor,
        batch: Tensor,
        edge_score: Tensor,
    ) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:

        from scipy.sparse.csgraph import connected_components

        # Keep only the edges whose score exceeds the threshold:
        edge_contract = edge_index[:, edge_score > self.threshold]

        # Each weakly-connected component of the contracted graph becomes one
        # cluster:
        adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
        _, cluster_np = connected_components(adj, directed=True,
                                             connection="weak")

        # Dense cluster assignment matrix `C`, adjacency matrix `A`, and edge
        # score matrix `S`:
        cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device)
        C = one_hot(cluster)
        A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)
        S = to_dense_adj(edge_index, edge_attr=edge_score,
                         max_num_nodes=x.size(0)).squeeze(0)

        # Nodes that are not part of any contracted edge keep their own
        # features with weight 1.0:
        A_contract = to_dense_adj(edge_contract,
                                  max_num_nodes=x.size(0)).squeeze(0)
        nodes_single = ((A_contract.sum(dim=-1) +
                         A_contract.sum(dim=-2)) == 0).nonzero()
        S[nodes_single, nodes_single] = 1.0

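        # For intuition (hypothetical values): if four nodes fall into
        # clusters [0, 0, 1, 1], then `C` is a (4, 2) one-hot assignment
        # matrix and `(S @ C).t() @ x` aggregates node features into each
        # cluster, weighting every node by the scores of its edges into that
        # cluster (nodes untouched by any contracted edge keep weight 1.0 via
        # the diagonal entries set above).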
        x_out = (S @ C).t() @ x
        edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))
        batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch)
        unpool_info = UnpoolInfo(edge_index, cluster, batch)

        return x_out, edge_index_out, batch_out, unpool_info

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self.in_channels})'
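

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): builds a tiny path graph
# holding a single example in the batch and runs one pooling step. Shapes and
# values are purely illustrative.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    pool = ClusterPooling(in_channels=16)

    x = torch.randn(6, 16)  # 6 nodes with 16 features each
    edge_index = torch.tensor([
        [0, 1, 2, 3, 4],
        [1, 2, 3, 4, 5],
    ])  # a directed path 0 -> 1 -> ... -> 5
    batch = torch.zeros(6, dtype=torch.long)  # all nodes belong to example 0

    x_out, edge_index_out, batch_out, unpool_info = pool(x, edge_index, batch)
    print(x_out.size(), edge_index_out.size(), batch_out.size())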