Skip to content

Commit 74ed43b

Browse files
committed
update
1 parent 1b195a0 commit 74ed43b

File tree

12 files changed

+111
-60
lines changed

12 files changed

+111
-60
lines changed

test/nn/conv/test_graph_conv.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,17 @@ def test_graph_conv():
9393
assert torch.allclose(jit((x1, None), adj3.t()), out2, atol=1e-6)
9494
assert torch.allclose(jit((x1, x2), adj4.t()), out3, atol=1e-6)
9595
assert torch.allclose(jit((x1, None), adj4.t()), out4, atol=1e-6)
96+
97+
98+
class EdgeGraphConv(GraphConv):
99+
def message(self, x_j, edge_weight):
100+
return edge_weight.view(-1, 1) * x_j
101+
102+
103+
def test_inheritance():
104+
x = torch.randn(4, 8)
105+
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
106+
edge_weight = torch.rand(4)
107+
108+
conv = EdgeGraphConv(8, 16)
109+
assert conv(x, edge_index, edge_weight).size() == (4, 16)

test/nn/conv/test_message_passing.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class MyConvWithSelfLoops(MessagePassing):
6868
def __init__(self, aggr: str = 'add'):
6969
super().__init__(aggr=aggr)
7070

71-
def forward(self, x: Tensor, edge_index: torch.Tensor) -> Tensor:
71+
def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
7272
edge_index, _ = add_self_loops(edge_index)
7373

7474
# propagate_type: (x: Tensor)
@@ -144,6 +144,25 @@ def test_my_conv_out_of_bounds():
144144
conv(x, edge_index, value)
145145

146146

147+
class MyCommentedConv(MessagePassing):
148+
r"""This layer calls `self.propagate()` internally."""
149+
def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
150+
# `self.propagate()` is used here to propagate messages.
151+
return self.propagate(edge_index, x=x)
152+
153+
154+
def test_my_commented_conv():
155+
# Check that `self.propagate` occurences in comments are correctly ignored.
156+
x = torch.randn(4, 8)
157+
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
158+
159+
conv = MyCommentedConv()
160+
conv(x, edge_index)
161+
162+
jit = torch.jit.script(conv)
163+
jit(x, edge_index)
164+
165+
147166
def test_my_conv_jit():
148167
x1 = torch.randn(4, 8)
149168
x2 = torch.randn(2, 16)

torch_geometric/data/dataset.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import re
44
import sys
55
import warnings
6-
from abc import ABC, abstractmethod
76
from collections.abc import Sequence
87
from typing import (
98
Any,
@@ -27,7 +26,7 @@
2726
MISSING = '???'
2827

2928

30-
class Dataset(torch.utils.data.Dataset, ABC):
29+
class Dataset(torch.utils.data.Dataset):
3130
r"""Dataset base class for creating graph datasets.
3231
See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/
3332
create_dataset.html>`__ for the accompanying tutorial.
@@ -79,12 +78,10 @@ def process(self) -> None:
7978
r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
8079
raise NotImplementedError
8180

82-
@abstractmethod
8381
def len(self) -> int:
8482
r"""Returns the number of data objects stored in the dataset."""
8583
raise NotImplementedError
8684

87-
@abstractmethod
8885
def get(self, idx: int) -> BaseData:
8986
r"""Gets the data object at index :obj:`idx`."""
9087
raise NotImplementedError

torch_geometric/data/in_memory_dataset.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import copy
22
import os.path as osp
33
import warnings
4-
from abc import ABC
54
from typing import (
65
Any,
76
Callable,
@@ -30,7 +29,7 @@
3029
from torch_geometric.io import fs
3130

3231

33-
class InMemoryDataset(Dataset, ABC):
32+
class InMemoryDataset(Dataset):
3433
r"""Dataset base class for creating graph datasets which easily fit
3534
into CPU memory.
3635
See `here <https://pytorch-geometric.readthedocs.io/en/latest/tutorial/

torch_geometric/inspector.py

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import typing
55
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union
66

7+
import torch
78
from torch import Tensor
89

910

@@ -32,9 +33,27 @@ def __init__(self, cls: Type):
3233
self._signature_dict: Dict[str, Signature] = {}
3334
self._source_dict: Dict[str, str] = {}
3435

36+
def _get_modules(self, cls: Type) -> List[str]:
37+
from torch_geometric.nn import MessagePassing
38+
39+
modules: List[str] = []
40+
for base_cls in cls.__bases__:
41+
if base_cls not in {object, torch.nn.Module, MessagePassing}:
42+
modules.extend(self._get_modules(base_cls))
43+
44+
modules.append(cls.__module__)
45+
return modules
46+
47+
@property
48+
def _modules(self) -> List[str]:
49+
return self._get_modules(self._cls)
50+
3551
@property
3652
def _globals(self) -> Dict[str, Any]:
37-
return sys.modules[self._cls.__module__].__dict__
53+
out: Dict[str, Any] = {}
54+
for module in self._modules:
55+
out.update(sys.modules[module].__dict__)
56+
return out
3857

3958
def __repr__(self) -> str:
4059
return f'{self.__class__.__name__}({self._cls.__name__})'
@@ -301,17 +320,6 @@ def collect_param_data(
301320

302321
# Inspecting Method Bodies ################################################
303322

304-
@property
305-
def can_read_source(self) -> bool:
306-
r"""Returns :obj:`True` if able to read the source file of the
307-
inspected class.
308-
"""
309-
try:
310-
inspect.getfile(self._cls)
311-
return True
312-
except Exception:
313-
return False
314-
315323
def get_source(self, cls: Optional[Type] = None) -> str:
316324
r"""Returns the source code of :obj:`cls`."""
317325
cls = cls or self._cls
@@ -388,6 +396,7 @@ def get_params_from_method_call(
388396
# (3) Parse the function call:
389397
for cls in self._cls.__mro__:
390398
source = self.get_source(cls)
399+
source = remove_comments(source)
391400
match = find_parenthesis_content(source, f'self.{func_name}')
392401
if match is not None:
393402
for i, kwarg in enumerate(split(match, sep=',')):
@@ -515,3 +524,12 @@ def split(content: str, sep: str) -> List[str]:
515524
if start != len(content): # Respect dangling `sep`:
516525
outs.append(content[start:].strip())
517526
return outs
527+
528+
529+
def remove_comments(content: str) -> str:
530+
content = re.sub(r'\s*#.*', '', content)
531+
content = re.sub(re.compile(r'r"""(.*?)"""', re.DOTALL), '', content)
532+
content = re.sub(re.compile(r'"""(.*?)"""', re.DOTALL), '', content)
533+
content = re.sub(re.compile(r"r'''(.*?)'''", re.DOTALL), '', content)
534+
content = re.sub(re.compile(r"'''(.*?)'''", re.DOTALL), '', content)
535+
return content

torch_geometric/io/fs.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,10 @@ def cp(
174174
if use_cache and clear_cache and cache_dir is not None:
175175
try:
176176
rm(cache_dir)
177-
except PermissionError: # FIXME
177+
except Exception: # FIXME
178178
# Windows test yield "PermissionError: The process cannot access
179-
# the file because it is being used by another process"
179+
# the file because it is being used by another process".
180+
# Users may also observe "OSError: Directory not empty".
180181
# This is a quick workaround until we figure out the deeper issue.
181182
pass
182183

torch_geometric/metrics/link_pred.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from abc import ABC, abstractmethod
21
from typing import Optional, Tuple, Union
32

43
import torch
@@ -15,7 +14,7 @@
1514
BaseMetric = torch.nn.Module # type: ignore
1615

1716

18-
class LinkPredMetric(BaseMetric, ABC):
17+
class LinkPredMetric(BaseMetric):
1918
r"""An abstract class for computing link prediction retrieval metrics.
2019
2120
Args:
@@ -117,7 +116,6 @@ def reset(self) -> None:
117116
self.accum.zero_()
118117
self.total.zero_()
119118

120-
@abstractmethod
121119
def _compute(self, pred_isin_mat: Tensor, y_count: Tensor) -> Tensor:
122120
r"""Compute the specific metric.
123121
To be implemented separately for each metric class.

torch_geometric/nn/conv/edge_updater.jinja

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ import torch_geometric.typing
88
from torch_geometric import is_compiling
99
from torch_geometric.utils import is_sparse
1010
from torch_geometric.typing import Size, SparseTensor
11-
11+
{% for module in modules %}
1212
from {{module}} import *
13+
{%- endfor %}
1314

1415

1516
{% include "collect.jinja" %}

torch_geometric/nn/conv/message_passing.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -166,13 +166,13 @@ def __init__(
166166
jinja_prefix = f'{self.__module__}_{self.__class__.__name__}'
167167
# Optimize `propagate()` via `*.jinja` templates:
168168
if not self.propagate.__module__.startswith(jinja_prefix):
169-
if self.inspector.can_read_source:
169+
try:
170170
module = module_from_template(
171171
module_name=f'{jinja_prefix}_propagate',
172172
template_path=osp.join(root_dir, 'propagate.jinja'),
173173
tmp_dirname='message_passing',
174174
# Keyword arguments:
175-
module=self.__module__,
175+
module=self.inspector._modules,
176176
collect_name='collect',
177177
signature=self._get_propagate_signature(),
178178
collect_param_dict=self.inspector.get_flat_param_dict(
@@ -185,34 +185,40 @@ def __init__(
185185
fuse=self.fuse,
186186
)
187187

188-
# Cache to potentially disable later on:
189188
self.__class__._orig_propagate = self.__class__.propagate
190189
self.__class__._jinja_propagate = module.propagate
191190

192191
self.__class__.propagate = module.propagate
193192
self.__class__.collect = module.collect
194-
else:
193+
except Exception: # pragma: no cover
195194
self.__class__._orig_propagate = self.__class__.propagate
196195
self.__class__._jinja_propagate = self.__class__.propagate
197196

198197
# Optimize `edge_updater()` via `*.jinja` templates (if implemented):
199198
if (self.inspector.implements('edge_update')
200-
and not self.edge_updater.__module__.startswith(jinja_prefix)
201-
and self.inspector.can_read_source):
202-
module = module_from_template(
203-
module_name=f'{jinja_prefix}_edge_updater',
204-
template_path=osp.join(root_dir, 'edge_updater.jinja'),
205-
tmp_dirname='message_passing',
206-
# Keyword arguments:
207-
module=self.__module__,
208-
collect_name='edge_collect',
209-
signature=self._get_edge_updater_signature(),
210-
collect_param_dict=self.inspector.get_param_dict(
211-
'edge_update'),
212-
)
199+
and not self.edge_updater.__module__.startswith(jinja_prefix)):
200+
try:
201+
module = module_from_template(
202+
module_name=f'{jinja_prefix}_edge_updater',
203+
template_path=osp.join(root_dir, 'edge_updater.jinja'),
204+
tmp_dirname='message_passing',
205+
# Keyword arguments:
206+
modules=self.inspector._modules,
207+
collect_name='edge_collect',
208+
signature=self._get_edge_updater_signature(),
209+
collect_param_dict=self.inspector.get_param_dict(
210+
'edge_update'),
211+
)
212+
213+
self.__class__._orig_edge_updater = self.__class__.edge_updater
214+
self.__class__._jinja_edge_updater = module.edge_updater
213215

214-
self.__class__.edge_updater = module.edge_updater
215-
self.__class__.edge_collect = module.edge_collect
216+
self.__class__.edge_updater = module.edge_updater
217+
self.__class__.edge_collect = module.edge_collect
218+
except Exception: # pragma: no cover
219+
self.__class__._orig_edge_updater = self.__class__.edge_updater
220+
self.__class__._jinja_edge_updater = (
221+
self.__class__.edge_updater)
216222

217223
# Explainability:
218224
self._explain: Optional[bool] = None

torch_geometric/nn/conv/propagate.jinja

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,9 @@ import torch_geometric.typing
88
from torch_geometric import is_compiling
99
from torch_geometric.utils import is_sparse
1010
from torch_geometric.typing import Size, SparseTensor
11-
11+
{% for module in modules %}
1212
from {{module}} import *
13+
{%- endfor %}
1314

1415

1516
{% include "collect.jinja" %}

0 commit comments

Comments
 (0)