Skip to content

Commit 5f4a21c

Browse files
authored
Memory-efficient one_hot implementation (#7005)
1 parent 86af56c commit 5f4a21c

File tree

15 files changed

+94
-41
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
77

88
### Added
99

10+
- Added a memory-efficient `utils.one_hot` implementation ([#7005](https://github.com/pyg-team/pytorch_geometric/pull/7005))
1011
- Added `HeteroDictLinear` and an optimized `FastHGTConv` module ([#6178](https://github.com/pyg-team/pytorch_geometric/pull/6178), [#6998](https://github.com/pyg-team/pytorch_geometric/pull/6998))
1112
- Added the `DenseGATConv` module ([#6928](https://github.com/pyg-team/pytorch_geometric/pull/6928))
1213
- Added `trim_to_layer` utility function for more efficient `NeighborLoader` use-cases ([#6661](https://github.com/pyg-team/pytorch_geometric/pull/6661))

test/nn/conv/test_wl_conv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import torch
2-
import torch.nn.functional as F
32
from torch_sparse import SparseTensor
43

54
from torch_geometric.nn import WLConv
5+
from torch_geometric.utils import one_hot
66

77

88
def test_wl_conv():
99
x1 = torch.tensor([1, 0, 0, 1])
10-
x2 = F.one_hot(x1).to(torch.float)
10+
x2 = one_hot(x1)
1111
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]])
1212
adj1 = SparseTensor.from_edge_index(edge_index)
1313
adj2 = adj1.to_torch_sparse_csc_tensor()

test/utils/test_one_hot.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import torch
2+
3+
from torch_geometric.utils import one_hot
4+
5+
6+
def test_one_hot():
7+
index = torch.tensor([0, 1, 2])
8+
9+
out = one_hot(index)
10+
assert out.size() == (3, 3)
11+
assert out.dtype == torch.float
12+
assert out.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
13+
14+
out = one_hot(index, num_classes=4, dtype=torch.long)
15+
assert out.size() == (3, 4)
16+
assert out.dtype == torch.long
17+
assert out.tolist() == [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0]]

torch_geometric/datasets/ged_dataset.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from typing import Callable, List, Optional
66

77
import torch
8-
import torch.nn.functional as F
98

109
from torch_geometric.data import (
1110
Data,
@@ -14,7 +13,7 @@
1413
extract_tar,
1514
extract_zip,
1615
)
17-
from torch_geometric.utils import to_undirected
16+
from torch_geometric.utils import one_hot, to_undirected
1817

1918

2019
class GEDDataset(InMemoryDataset):
@@ -201,8 +200,7 @@ def process(self):
201200
x = torch.zeros(data.num_nodes, dtype=torch.long)
202201
for node, info in G.nodes(data=True):
203202
x[int(node)] = self.types.index(info['type'])
204-
data.x = F.one_hot(x, num_classes=len(self.types)).to(
205-
torch.float)
203+
data.x = one_hot(x, num_classes=len(self.types))
206204

207205
if self.pre_filter is not None and not self.pre_filter(data):
208206
continue

torch_geometric/datasets/linkx_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44
import numpy as np
55
import torch
6-
import torch.nn.functional as F
76

87
from torch_geometric.data import Data, InMemoryDataset, download_url
8+
from torch_geometric.utils import one_hot
99

1010

1111
class LINKXDataset(InMemoryDataset):
@@ -132,7 +132,7 @@ def _process_facebook(self):
132132
x = torch.cat([metadata[:, :1], metadata[:, 2:]], dim=-1)
133133
for i in range(x.size(1)):
134134
_, out = x[:, i].unique(return_inverse=True)
135-
xs.append(F.one_hot(out).to(torch.float))
135+
xs.append(one_hot(out))
136136
x = torch.cat(xs, dim=-1)
137137

138138
data = Data(x=x, edge_index=edge_index, y=y)

torch_geometric/datasets/qm9.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from typing import Callable, List, Optional
55

66
import torch
7-
import torch.nn.functional as F
87
from tqdm import tqdm
98

109
from torch_geometric.data import (
@@ -13,7 +12,7 @@
1312
download_url,
1413
extract_zip,
1514
)
16-
from torch_geometric.utils import scatter
15+
from torch_geometric.utils import one_hot, scatter
1716

1817
HAR2EV = 27.211386246
1918
KCALMOL2EV = 0.04336414
@@ -271,8 +270,7 @@ def process(self):
271270

272271
edge_index = torch.tensor([row, col], dtype=torch.long)
273272
edge_type = torch.tensor(edge_type, dtype=torch.long)
274-
edge_attr = F.one_hot(edge_type,
275-
num_classes=len(bonds)).to(torch.float)
273+
edge_attr = one_hot(edge_type, num_classes=len(bonds))
276274

277275
perm = (edge_index[0] * N + edge_index[1]).argsort()
278276
edge_index = edge_index[:, perm]
@@ -283,10 +281,10 @@ def process(self):
283281
hs = (z == 1).to(torch.float)
284282
num_hs = scatter(hs[row], col, dim_size=N, reduce='sum').tolist()
285283

286-
x1 = F.one_hot(torch.tensor(type_idx), num_classes=len(types))
284+
x1 = one_hot(torch.tensor(type_idx), num_classes=len(types))
287285
x2 = torch.tensor([atomic_number, aromatic, sp, sp2, sp3, num_hs],
288286
dtype=torch.float).t().contiguous()
289-
x = torch.cat([x1.to(torch.float), x2], dim=-1)
287+
x = torch.cat([x1, x2], dim=-1)
290288

291289
y = target[i].unsqueeze(0)
292290
name = mol.GetProp('_Name')

torch_geometric/io/sdf.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import torch
2-
import torch.nn.functional as F
32

43
from torch_geometric.data import Data
54
from torch_geometric.io import parse_txt_array
6-
from torch_geometric.utils import coalesce
5+
from torch_geometric.utils import coalesce, one_hot
76

87
elems = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}
98

@@ -15,7 +14,7 @@ def parse_sdf(src):
1514
atom_block = src[1:num_atoms + 1]
1615
pos = parse_txt_array(atom_block, end=3)
1716
x = torch.tensor([elems[item.split()[3]] for item in atom_block])
18-
x = F.one_hot(x, num_classes=len(elems))
17+
x = one_hot(x, num_classes=len(elems))
1918

2019
bond_block = src[1 + num_atoms:1 + num_atoms + num_bonds]
2120
row, col = parse_txt_array(bond_block, end=2, dtype=torch.long).t() - 1

torch_geometric/io/tu.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,10 @@
44

55
import numpy as np
66
import torch
7-
import torch.nn.functional as F
87

98
from torch_geometric.data import Data
109
from torch_geometric.io import read_txt_array
11-
from torch_geometric.utils import coalesce, remove_self_loops
10+
from torch_geometric.utils import coalesce, one_hot, remove_self_loops
1211

1312
names = [
1413
'A', 'graph_indicator', 'node_labels', 'node_attributes'
@@ -36,8 +35,11 @@ def read_tu_data(folder, prefix):
3635
node_labels = node_labels.unsqueeze(-1)
3736
node_labels = node_labels - node_labels.min(dim=0)[0]
3837
node_labels = node_labels.unbind(dim=-1)
39-
node_labels = [F.one_hot(x, num_classes=-1) for x in node_labels]
40-
node_labels = torch.cat(node_labels, dim=-1).to(torch.float)
38+
node_labels = [one_hot(x) for x in node_labels]
39+
if len(node_labels) == 1:
40+
node_labels = node_labels[0]
41+
else:
42+
node_labels = torch.cat(node_labels, dim=-1)
4143

4244
edge_attributes = torch.empty((edge_index.size(1), 0))
4345
if 'edge_attributes' in names:
@@ -52,8 +54,11 @@ def read_tu_data(folder, prefix):
5254
edge_labels = edge_labels.unsqueeze(-1)
5355
edge_labels = edge_labels - edge_labels.min(dim=0)[0]
5456
edge_labels = edge_labels.unbind(dim=-1)
55-
edge_labels = [F.one_hot(e, num_classes=-1) for e in edge_labels]
56-
edge_labels = torch.cat(edge_labels, dim=-1).to(torch.float)
57+
edge_labels = [one_hot(e) for e in edge_labels]
58+
if len(edge_labels) == 1:
59+
edge_labels = edge_labels[0]
60+
else:
61+
edge_labels = torch.cat(edge_labels, dim=-1)
5762

5863
x = cat([node_attributes, node_labels])
5964
edge_attr = cat([edge_attributes, edge_labels])

torch_geometric/nn/conv/rgcn_conv.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from typing import Optional, Tuple, Union
22

33
import torch
4-
import torch.nn.functional as F
54
from torch import Tensor
65
from torch.nn import Parameter
76
from torch.nn import Parameter as Param
@@ -16,7 +15,7 @@
1615
pyg_lib,
1716
torch_sparse,
1817
)
19-
from torch_geometric.utils import index_sort, scatter, spmm
18+
from torch_geometric.utils import index_sort, one_hot, scatter, spmm
2019
from torch_geometric.utils.sparse import index2ptr
2120

2221

@@ -351,7 +350,7 @@ def aggregate(self, inputs: Tensor, edge_type: Tensor, index: Tensor,
351350

352351
# Compute normalization in separation for each `edge_type`.
353352
if self.aggr == 'mean':
354-
norm = F.one_hot(edge_type, self.num_relations).to(torch.float)
353+
norm = one_hot(edge_type, self.num_relations, dtype=inputs.dtype)
355354
norm = scatter(norm, index, dim=0, dim_size=dim_size)[index]
356355
norm = torch.gather(norm, 1, edge_type.view(-1, 1))
357356
norm = 1. / norm.clamp_(1.)

torch_geometric/nn/models/correct_and_smooth.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import torch
2-
import torch.nn.functional as F
32
from torch import Tensor
43

54
from torch_geometric.nn.models import LabelPropagation
65
from torch_geometric.typing import Adj, OptTensor
6+
from torch_geometric.utils import one_hot
77

88

99
class CorrectAndSmooth(torch.nn.Module):
@@ -97,8 +97,8 @@ def correct(self, y_soft: Tensor, y_true: Tensor, mask: Tensor,
9797
assert y_true.size(0) == numel
9898

9999
if y_true.dtype == torch.long and y_true.size(0) == y_true.numel():
100-
y_true = F.one_hot(y_true.view(-1), y_soft.size(-1))
101-
y_true = y_true.to(y_soft.dtype)
100+
y_true = one_hot(y_true.view(-1), num_classes=y_soft.size(-1),
101+
dtype=y_soft.dtype)
102102

103103
error = torch.zeros_like(y_soft)
104104
error[mask] = y_true - y_soft[mask]
@@ -141,8 +141,8 @@ def smooth(self, y_soft: Tensor, y_true: Tensor, mask: Tensor,
141141
assert y_true.size(0) == numel
142142

143143
if y_true.dtype == torch.long and y_true.size(0) == y_true.numel():
144-
y_true = F.one_hot(y_true.view(-1), y_soft.size(-1))
145-
y_true = y_true.to(y_soft.dtype)
144+
y_true = one_hot(y_true.view(-1), num_classes=y_soft.size(-1),
145+
dtype=y_soft.dtype)
146146

147147
y_soft = y_soft.clone()
148148
y_soft[mask] = y_true

0 commit comments

Comments
 (0)