Skip to content

Commit ecfa240

Browse files
committed
compiler: add padded_dimensions property
1 parent be6e4e2 commit ecfa240

3 files changed

Lines changed: 25 additions & 24 deletions

File tree

devito/passes/clusters/buffering.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,7 @@ def generate_buffers(clusters, key, sregistry, options, **kwargs):
376376
assert len(buffers) == 1, "Unexpected form of multi-level buffering"
377377
buffer, = buffers
378378
xd = buffer.indices[dim]
379-
# The new buffer is fed by `buffer`, so it inherits its padding
380-
# policy regardless of `f`'s
379+
# The new buffer is derived from `buffer`, so it inherits its padding policy
381380
extra_kwargs = {'is_autopaddable': buffer.is_autopaddable}
382381
else:
383382
size = infer_buffer_size(f, dim, clusters)

devito/passes/iet/linearization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def key1(f, d):
7070
if f.is_regular:
7171
# For paddable objects the following holds:
7272
# `same dim + same halo + same padding_dtype => same (auto-)padding`
73-
pad_key = f.__padding_dtype__ if d is f.dimensions[-1] else None
73+
pad_key = f.__padding_dtype__ if d in f._padded_dimensions else None
7474

7575
return (d, f._size_halo[d], pad_key)
7676
else:

devito/types/basic.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -907,36 +907,26 @@ def __padding_dtype__(self):
907907
return np.float32
908908

909909
def __padding_setup_smart__(self, **kwargs):
910-
nopadding = ((0, 0),)*self.ndim
911-
912-
if not self.__padding_dtype__:
913-
return nopadding
914-
915-
# The padded Dimension
916-
if not self.space_dimensions:
917-
return nopadding
918-
d = self.space_dimensions[-1]
910+
padding = [(0, 0)]*self.ndim
919911

920-
# Last space Dimension is not the most inner Dimension
921-
if d != self.dimensions[-1]:
922-
return nopadding
912+
if not self.__padding_dtype__ or not self._padded_dimensions:
913+
return tuple(padding)
923914

924915
mmts = configuration['platform'].max_mem_trans_size(self.__padding_dtype__)
925916

926-
snp = self._size_nopad[d]
927-
remainder = snp % mmts
928-
if remainder == 0:
929-
# Already a multiple of `mmts`, no need to pad
930-
return nopadding
931-
else:
917+
for d in self._padded_dimensions:
918+
snp = self._size_nopad[d]
919+
remainder = snp % mmts
920+
if remainder == 0:
921+
# Already a multiple of `mmts`, no need to pad
922+
continue
923+
932924
from devito.symbolics import RoundUp # noqa
933925
v = RoundUp(snp, mmts) - snp
934926
if v.is_Integer:
935927
v = int(v)
936928

937-
dpadding = (0, v)
938-
padding = [(0, 0)]*self.ndim
939-
padding[self.dimensions.index(d)] = dpadding
929+
padding[self.dimensions.index(d)] = (0, v)
940930

941931
return tuple(padding)
942932

@@ -987,6 +977,18 @@ def dimensions(self):
987977
"""Tuple of Dimensions representing the object indices."""
988978
return DimensionTuple(*self._dimensions, getters=self._dimensions)
989979

980+
@property
981+
def _padded_dimensions(self):
982+
try:
983+
d = self.space_dimensions[-1]
984+
except IndexError:
985+
return ()
986+
987+
if d is self.dimensions[-1]:
988+
return (d,)
989+
else:
990+
return ()
991+
990992
@cached_property
991993
def space_dimensions(self):
992994
"""Tuple of Dimensions defining the physical space."""

0 commit comments

Comments
 (0)