Skip to content

Commit df3df3b

Browse files
thijssnellemanpre-commit-ci[bot]rusty1s
authored
Added the ClusterPooling layer (#9627)
Our method that we will present at ECML MLG workshop. Based on [EdgePooling](https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/pool/edge_pool.py) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Matthias Fey <[email protected]>
1 parent 241a8c3 commit df3df3b

File tree

5 files changed

+186
-6
lines changed

5 files changed

+186
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
77

88
### Added
99

10+
- Added the `ClusterPooling` layer ([#9627](https://github.com/pyg-team/pytorch_geometric/pull/9627))
1011
- Added the `LinkPredMRR` metric ([#9632](https://github.com/pyg-team/pytorch_geometric/pull/9632))
1112
- Added PyTorch 2.4 support ([#9594](https://github.com/pyg-team/pytorch_geometric/pull/9594))
1213
- Added `utils.normalize_edge_index` for symmetric/asymmetric normalization of graph edges ([#9554](https://github.com/pyg-team/pytorch_geometric/pull/9554))

test/nn/pool/test_cluster_pool.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import pytest
2+
import torch
3+
4+
from torch_geometric.nn import ClusterPooling
5+
from torch_geometric.testing import withPackage
6+
7+
8+
@withPackage('scipy')
9+
@pytest.mark.parametrize('edge_score_method', [
10+
'tanh',
11+
'sigmoid',
12+
'log_softmax',
13+
])
14+
def test_cluster_pooling(edge_score_method):
15+
x = torch.tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [-1.0]])
16+
edge_index = torch.tensor([
17+
[0, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6],
18+
[1, 2, 3, 6, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4, 0],
19+
])
20+
batch = torch.tensor([0, 0, 0, 0, 1, 1, 0])
21+
22+
op = ClusterPooling(in_channels=1, edge_score_method=edge_score_method)
23+
assert str(op) == 'ClusterPooling(1)'
24+
op.reset_parameters()
25+
26+
x, edge_index, batch, unpool_info = op(x, edge_index, batch)
27+
assert x.size(0) <= 7
28+
assert edge_index.size(0) == 2
29+
if edge_index.numel() > 0:
30+
assert edge_index.min() >= 0
31+
assert edge_index.max() < x.size(0)
32+
assert batch.size() == (x.size(0), )

torch_geometric/nn/pool/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@
77
import torch_geometric.typing
88
from torch_geometric.typing import OptTensor, torch_cluster
99

10-
from .asap import ASAPooling
1110
from .avg_pool import avg_pool, avg_pool_neighbor_x, avg_pool_x
12-
from .edge_pool import EdgePooling
1311
from .glob import global_add_pool, global_max_pool, global_mean_pool
1412
from .knn import (KNNIndex, L2KNNIndex, MIPSKNNIndex, ApproxL2KNNIndex,
1513
ApproxMIPSKNNIndex)
1614
from .graclus import graclus
1715
from .max_pool import max_pool, max_pool_neighbor_x, max_pool_x
18-
from .mem_pool import MemPooling
19-
from .pan_pool import PANPooling
20-
from .sag_pool import SAGPooling
2116
from .topk_pool import TopKPooling
17+
from .sag_pool import SAGPooling
18+
from .edge_pool import EdgePooling
19+
from .cluster_pool import ClusterPooling
20+
from .asap import ASAPooling
21+
from .pan_pool import PANPooling
22+
from .mem_pool import MemPooling
2223
from .voxel_grid import voxel_grid
2324
from .approx_knn import approx_knn, approx_knn_graph
2425

@@ -344,6 +345,7 @@ def nearest(
344345
'TopKPooling',
345346
'SAGPooling',
346347
'EdgePooling',
348+
'ClusterPooling',
347349
'ASAPooling',
348350
'PANPooling',
349351
'MemPooling',
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
from typing import NamedTuple, Optional, Tuple
2+
3+
import torch
4+
import torch.nn.functional as F
5+
from torch import Tensor
6+
7+
from torch_geometric.utils import (
8+
dense_to_sparse,
9+
one_hot,
10+
to_dense_adj,
11+
to_scipy_sparse_matrix,
12+
)
13+
14+
15+
class UnpoolInfo(NamedTuple):
16+
edge_index: Tensor
17+
cluster: Tensor
18+
batch: Tensor
19+
20+
21+
class ClusterPooling(torch.nn.Module):
22+
r"""The cluster pooling operator from the `"Edge-Based Graph Component
23+
Pooling" <paper url>`_ paper.
24+
25+
:class:`ClusterPooling` computes a score for each edge.
26+
Based on the selected edges, graph clusters are calculated and compressed
27+
to one node using the injective :obj:`"sum"` aggregation function.
28+
Edges are remapped based on the nodes created by each cluster and the
29+
original edges.
30+
31+
Args:
32+
in_channels (int): Size of each input sample.
33+
edge_score_method (str, optional): The function to apply
34+
to compute the edge score from raw edge scores (:obj:`"tanh"`,
35+
:obj:`"sigmoid"`, :obj:`"log_softmax"`). (default: :obj:`"tanh"`)
36+
dropout (float, optional): The probability with
37+
which to drop edge scores during training. (default: :obj:`0.0`)
38+
threshold (float, optional): The threshold of edge scores. If set to
39+
:obj:`None`, will be automatically inferred depending on
40+
:obj:`edge_score_method`. (default: :obj:`None`)
41+
"""
42+
def __init__(
43+
self,
44+
in_channels: int,
45+
edge_score_method: str = 'tanh',
46+
dropout: float = 0.0,
47+
threshold: Optional[float] = None,
48+
):
49+
super().__init__()
50+
assert edge_score_method in ['tanh', 'sigmoid', 'log_softmax']
51+
52+
if threshold is None:
53+
threshold = 0.5 if edge_score_method == 'sigmoid' else 0.0
54+
55+
self.in_channels = in_channels
56+
self.edge_score_method = edge_score_method
57+
self.dropout = dropout
58+
self.threshhold = threshold
59+
60+
self.lin = torch.nn.Linear(2 * in_channels, 1)
61+
62+
def reset_parameters(self):
63+
r"""Resets all learnable parameters of the module."""
64+
self.lin.reset_parameters()
65+
66+
def forward(
67+
self,
68+
x: Tensor,
69+
edge_index: Tensor,
70+
batch: Tensor,
71+
) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
72+
r"""Forward pass.
73+
74+
Args:
75+
x (torch.Tensor): The node features.
76+
edge_index (torch.Tensor): The edge indices.
77+
batch (torch.Tensor): Batch vector
78+
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns
79+
each node to a specific example.
80+
81+
Return types:
82+
* **x** *(torch.Tensor)* - The pooled node features.
83+
* **edge_index** *(torch.Tensor)* - The coarsened edge indices.
84+
* **batch** *(torch.Tensor)* - The coarsened batch vector.
85+
* **unpool_info** *(UnpoolInfo)* - Information that can be consumed
86+
for unpooling.
87+
"""
88+
mask = edge_index[0] != edge_index[1]
89+
edge_index = edge_index[:, mask]
90+
91+
edge_attr = torch.cat(
92+
[x[edge_index[0]], x[edge_index[1]]],
93+
dim=-1,
94+
)
95+
edge_score = self.lin(edge_attr).view(-1)
96+
edge_score = F.dropout(edge_score, p=self.dropout,
97+
training=self.training)
98+
99+
if self.edge_score_method == 'tanh':
100+
edge_score = edge_score.tanh()
101+
elif self.edge_score_method == 'sigmoid':
102+
edge_score = edge_score.sigmoid()
103+
else:
104+
assert self.edge_score_method == 'log_softmax'
105+
edge_score = F.log_softmax(edge_score, dim=0)
106+
107+
return self._merge_edges(x, edge_index, batch, edge_score)
108+
109+
def _merge_edges(
110+
self,
111+
x: Tensor,
112+
edge_index: Tensor,
113+
batch: Tensor,
114+
edge_score: Tensor,
115+
) -> Tuple[Tensor, Tensor, Tensor, UnpoolInfo]:
116+
117+
from scipy.sparse.csgraph import connected_components
118+
119+
edge_contract = edge_index[:, edge_score > self.threshhold]
120+
121+
adj = to_scipy_sparse_matrix(edge_contract, num_nodes=x.size(0))
122+
_, cluster_np = connected_components(adj, directed=True,
123+
connection="weak")
124+
125+
cluster = torch.tensor(cluster_np, dtype=torch.long, device=x.device)
126+
C = one_hot(cluster)
127+
A = to_dense_adj(edge_index, max_num_nodes=x.size(0)).squeeze(0)
128+
S = to_dense_adj(edge_index, edge_attr=edge_score,
129+
max_num_nodes=x.size(0)).squeeze(0)
130+
131+
A_contract = to_dense_adj(edge_contract,
132+
max_num_nodes=x.size(0)).squeeze(0)
133+
nodes_single = ((A_contract.sum(dim=-1) +
134+
A_contract.sum(dim=-2)) == 0).nonzero()
135+
S[nodes_single, nodes_single] = 1.0
136+
137+
x_out = (S @ C).t() @ x
138+
edge_index_out, _ = dense_to_sparse((C.T @ A @ C).fill_diagonal_(0))
139+
batch_out = batch.new_empty(x_out.size(0)).scatter_(0, cluster, batch)
140+
unpool_info = UnpoolInfo(edge_index, cluster, batch)
141+
142+
return x_out, edge_index_out, batch_out, unpool_info
143+
144+
def __repr__(self) -> str:
145+
return f'{self.__class__.__name__}({self.in_channels})'

torch_geometric/nn/pool/edge_pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
self,
5959
in_channels: int,
6060
edge_score_method: Optional[Callable] = None,
61-
dropout: Optional[float] = 0.0,
61+
dropout: float = 0.0,
6262
add_to_edge_score: float = 0.5,
6363
):
6464
super().__init__()

0 commit comments

Comments
 (0)