Skip to content
Open
2 changes: 2 additions & 0 deletions keras/src/backend/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from keras.src.backend.jax.core import shape
from keras.src.backend.jax.core import stop_gradient
from keras.src.backend.jax.core import vectorized_map
from keras.src.backend.jax.nn import adaptive_avg_pool
from keras.src.backend.jax.nn import adaptive_max_pool
from keras.src.backend.jax.rnn import cudnn_ok
from keras.src.backend.jax.rnn import gru
from keras.src.backend.jax.rnn import lstm
Expand Down
152 changes: 152 additions & 0 deletions keras/src/backend/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,3 +1464,155 @@ def _pair(x):
# ---- reshape -> (N, C*kH*kW, L) ----
_, CKK, oH, oW = patches.shape
return patches.reshape(N, CKK, oH * oW)


def _adaptive_pool_start_index(output_idx, output_size, input_size):
"""Calculate start index for adaptive pooling (PyTorch compatible)."""
return jnp.floor((output_idx * input_size) / output_size).astype(jnp.int32)


def _adaptive_pool_end_index(output_idx, output_size, input_size):
"""Calculate end index for adaptive pooling (PyTorch compatible)."""
return jnp.ceil(((output_idx + 1) * input_size) / output_size).astype(
jnp.int32
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The helper functions _adaptive_pool_start_index and _adaptive_pool_end_index are defined but not used. This dead code should be removed to improve code clarity.



def adaptive_avg_pool(
    inputs, output_size, data_format="channels_last", name=None
):
    """Adaptive average pooling for JAX backend (PyTorch-compatible).

    Output position `i` along an axis of length `n` pooled to `m` positions
    averages the input window `[floor(i * n / m), ceil((i + 1) * n / m))`.

    Because every 2D window is a rectangle, the mean over the window equals
    the mean over its rows of the per-row means, so the two spatial axes are
    pooled one after the other. Window bounds are computed with Python
    integers, which keeps all slices static and therefore traceable under
    `jax.jit`.

    Args:
        inputs: 4D array, `(batch, height, width, channels)` when
            `data_format == "channels_last"`, otherwise
            `(batch, channels, height, width)`.
        output_size: Integer or iterable of 2 integers giving the target
            `(height, width)`. A single integer is used for every spatial
            axis.
        data_format: `"channels_last"` or `"channels_first"`. Any value
            other than `"channels_last"` is treated as channels-first.
        name: Unused; accepted for signature compatibility.

    Returns:
        Array with the spatial axes resized to `output_size`; batch and
        channel axes are unchanged.

    Raises:
        NotImplementedError: If `output_size` does not describe exactly 2
            spatial dimensions.
        ValueError: If `inputs` is not 4D or an output size is not positive.
    """
    spatial_dims = inputs.ndim - 2
    if isinstance(output_size, int):
        output_size = (output_size,) * spatial_dims
    else:
        output_size = tuple(output_size)

    if len(output_size) != 2:
        raise NotImplementedError(
            "Only 2D adaptive pooling is currently supported"
        )
    if inputs.ndim != 4:
        raise ValueError(
            "Adaptive 2D pooling expects a 4D input. "
            f"Received: inputs.shape={inputs.shape}"
        )
    if any(size < 1 for size in output_size):
        raise ValueError(
            "`output_size` entries must be positive integers. "
            f"Received: output_size={output_size}"
        )

    def _pool_axis(x, axis, out_size):
        # Average-pool `x` along one axis down to `out_size` positions.
        in_size = x.shape[axis]
        if in_size % out_size == 0:
            # All windows have the same width and do not overlap: fold the
            # axis into (out_size, width) and reduce in a single op.
            width = in_size // out_size
            folded = x.shape[:axis] + (out_size, width) + x.shape[axis + 1 :]
            return jnp.mean(x.reshape(folded), axis=axis + 1)

        pieces = []
        for idx in range(out_size):
            # Exact integer floor / ceil; float rounding could shift a
            # boundary by one for large sizes.
            start = (idx * in_size) // out_size
            stop = -(-((idx + 1) * in_size) // out_size)
            index = [slice(None)] * x.ndim
            index[axis] = slice(start, stop)
            pieces.append(jnp.mean(x[tuple(index)], axis=axis, keepdims=True))
        return jnp.concatenate(pieces, axis=axis)

    h_axis, w_axis = (1, 2) if data_format == "channels_last" else (2, 3)
    out_h, out_w = output_size
    pooled_rows = _pool_axis(inputs, h_axis, out_h)
    return _pool_axis(pooled_rows, w_axis, out_w)


def adaptive_max_pool(
    inputs, output_size, data_format="channels_last", name=None
):
    """Adaptive max pooling for JAX backend (PyTorch-compatible).

    Output position `i` along an axis of length `n` pooled to `m` positions
    takes the maximum over the input window
    `[floor(i * n / m), ceil((i + 1) * n / m))`.

    The maximum over a rectangular window equals the maximum over its rows
    of the per-row maxima, so the two spatial axes are pooled one after the
    other. Window bounds are computed with Python integers, which keeps all
    slices static and therefore traceable under `jax.jit`.

    Args:
        inputs: 4D array, `(batch, height, width, channels)` when
            `data_format == "channels_last"`, otherwise
            `(batch, channels, height, width)`.
        output_size: Integer or iterable of 2 integers giving the target
            `(height, width)`. A single integer is used for every spatial
            axis.
        data_format: `"channels_last"` or `"channels_first"`. Any value
            other than `"channels_last"` is treated as channels-first.
        name: Unused; accepted for signature compatibility.

    Returns:
        Array with the spatial axes resized to `output_size`; batch and
        channel axes are unchanged.

    Raises:
        NotImplementedError: If `output_size` does not describe exactly 2
            spatial dimensions.
        ValueError: If `inputs` is not 4D or an output size is not positive.
    """
    spatial_dims = inputs.ndim - 2
    if isinstance(output_size, int):
        output_size = (output_size,) * spatial_dims
    else:
        output_size = tuple(output_size)

    if len(output_size) != 2:
        raise NotImplementedError(
            "Only 2D adaptive pooling is currently supported"
        )
    if inputs.ndim != 4:
        raise ValueError(
            "Adaptive 2D pooling expects a 4D input. "
            f"Received: inputs.shape={inputs.shape}"
        )
    if any(size < 1 for size in output_size):
        raise ValueError(
            "`output_size` entries must be positive integers. "
            f"Received: output_size={output_size}"
        )

    def _pool_axis(x, axis, out_size):
        # Max-pool `x` along one axis down to `out_size` positions.
        in_size = x.shape[axis]
        if in_size % out_size == 0:
            # All windows have the same width and do not overlap: fold the
            # axis into (out_size, width) and reduce in a single op.
            width = in_size // out_size
            folded = x.shape[:axis] + (out_size, width) + x.shape[axis + 1 :]
            return jnp.max(x.reshape(folded), axis=axis + 1)

        pieces = []
        for idx in range(out_size):
            # Exact integer floor / ceil; float rounding could shift a
            # boundary by one for large sizes.
            start = (idx * in_size) // out_size
            stop = -(-((idx + 1) * in_size) // out_size)
            index = [slice(None)] * x.ndim
            index[axis] = slice(start, stop)
            pieces.append(jnp.max(x[tuple(index)], axis=axis, keepdims=True))
        return jnp.concatenate(pieces, axis=axis)

    h_axis, w_axis = (1, 2) if data_format == "channels_last" else (2, 3)
    out_h, out_w = output_size
    pooled_rows = _pool_axis(inputs, h_axis, out_h)
    return _pool_axis(pooled_rows, w_axis, out_w)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The functions adaptive_avg_pool and adaptive_max_pool are nearly identical, with the only difference being the pooling operation (jnp.mean vs jnp.max). This code duplication can be avoided by creating a generic _adaptive_pool helper function that takes the pooling function as an argument. This would improve maintainability and reduce redundancy.

For example:

def _adaptive_pool(inputs, output_size, data_format, pool_op):
    # ... common setup code ...
    for i in range(out_h):
        for j in range(out_w):
            # ... common region calculation ...
            if data_format == "channels_last":
                region = inputs[:, start_h:end_h, start_w:end_w, :]
                pooled = pool_op(region, axis=(1, 2))
            else:  # channels_first
                region = inputs[:, :, start_h:end_h, start_w:end_w]
                pooled = pool_op(region, axis=(2, 3))
            result_list.append(pooled)
    # ... common reshape and transpose code ...
    return output

def adaptive_avg_pool(inputs, output_size, data_format="channels_last", name=None):
    # ...
    return _adaptive_pool(inputs, output_size, data_format, jnp.mean)

def adaptive_max_pool(inputs, output_size, data_format="channels_last", name=None):
    # ...
    return _adaptive_pool(inputs, output_size, data_format, jnp.max)

Note that this refactoring suggestion still contains the performance issue mentioned in another comment. The primary goal here is to illustrate how to reduce code duplication.

4 changes: 4 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@
SpectralNormalization,
)
from keras.src.layers.normalization.unit_normalization import UnitNormalization
from keras.src.layers.pooling.adaptive_average_pooling2d import (
AdaptiveAveragePooling2D,
)
from keras.src.layers.pooling.adaptive_max_pooling2d import AdaptiveMaxPooling2D
from keras.src.layers.pooling.average_pooling1d import AveragePooling1D
from keras.src.layers.pooling.average_pooling2d import AveragePooling2D
from keras.src.layers.pooling.average_pooling3d import AveragePooling3D
Expand Down
4 changes: 4 additions & 0 deletions keras/src/layers/pooling/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from keras.src.layers.pooling.adaptive_average_pooling2d import (
AdaptiveAveragePooling2D,
)
from keras.src.layers.pooling.adaptive_max_pooling2d import AdaptiveMaxPooling2D
112 changes: 112 additions & 0 deletions keras/src/layers/pooling/adaptive_average_pooling2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Adaptive Average Pooling 2D layer."""

from keras import config
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer


@keras_export("keras.layers.AdaptiveAveragePooling2D")
class AdaptiveAveragePooling2D(Layer):
    """Adaptive average pooling operation for 2D spatial data.

    This layer applies an adaptive average pooling operation, which pools the
    input such that the output has a target shape specified by `output_size`,
    regardless of the input shape. The kernel size and stride are automatically
    computed to achieve the target output size.

    Args:
        output_size: Integer or tuple of 2 integers, specifying the target
            output size `(height, width)`. If a single integer is provided,
            the same value is used for both dimensions. Every entry must be
            a positive integer.
        data_format: string, either `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs. `"channels_last"`
            corresponds to inputs with shape `(batch, height, width, channels)`
            while `"channels_first"` corresponds to inputs with shape
            `(batch, channels, height, width)`. Defaults to the value found in
            your Keras config file at `~/.keras/keras.json`. If never set, then
            "channels_last" will be used.

    Input shape:
        - If `data_format="channels_last"`:
            4D tensor with shape `(batch_size, height, width, channels)`.
        - If `data_format="channels_first"`:
            4D tensor with shape `(batch_size, channels, height, width)`.

    Output shape:
        - If `data_format="channels_last"`:
            4D tensor with shape
            `(batch_size, output_height, output_width, channels)`.
        - If `data_format="channels_first"`:
            4D tensor with shape
            `(batch_size, channels, output_height, output_width)`.

    Raises:
        TypeError: If `output_size` is neither an integer nor a list/tuple.
        ValueError: If `output_size` does not have exactly 2 entries, if an
            entry is not a positive integer, or if `data_format` is invalid.

    Examples:

    >>> input_img = np.random.rand(1, 64, 64, 3)
    >>> layer = keras.layers.AdaptiveAveragePooling2D(output_size=(32, 32))
    >>> output_img = layer(input_img)
    >>> output_img.shape
    (1, 32, 32, 3)

    >>> # Single integer for square output
    >>> layer = keras.layers.AdaptiveAveragePooling2D(output_size=7)
    >>> output_img = layer(input_img)
    >>> output_img.shape
    (1, 7, 7, 3)
    """

    def __init__(self, output_size, data_format=None, **kwargs):
        super().__init__(**kwargs)
        if isinstance(output_size, int):
            requested = (output_size, output_size)
        elif isinstance(output_size, (list, tuple)):
            if len(output_size) != 2:
                raise ValueError(
                    "`output_size` must be an integer or tuple of 2 integers. "
                    f"Received: output_size={output_size}"
                )
            requested = tuple(output_size)
        else:
            raise TypeError(
                "`output_size` must be an integer or tuple of 2 integers. "
                f"Received: output_size={output_size} of type "
                f"{type(output_size)}"
            )

        # Reject invalid entries here so the failure points at the layer
        # argument instead of surfacing later inside the backend op.
        normalized = []
        for size in requested:
            # `bool` is an `int` subclass but is never a meaningful size;
            # `__index__` admits integer-like scalars (e.g. NumPy integers).
            is_integer = hasattr(size, "__index__") and not isinstance(
                size, bool
            )
            if not is_integer or int(size) < 1:
                raise ValueError(
                    "`output_size` entries must be positive integers. "
                    f"Received: output_size={output_size}"
                )
            normalized.append(int(size))
        self.output_size = tuple(normalized)

        self.data_format = data_format or config.image_data_format()

        if self.data_format not in {"channels_first", "channels_last"}:
            raise ValueError(
                f"Invalid data_format: {self.data_format}. "
                "Must be either 'channels_first' or 'channels_last'."
            )

    def call(self, inputs):
        """Pool `inputs` to `self.output_size` via `ops.adaptive_avg_pool`."""
        return ops.adaptive_avg_pool(
            inputs, output_size=self.output_size, data_format=self.data_format
        )

    def compute_output_shape(self, input_shape):
        """Return the 4D output shape for a 4D `input_shape`.

        Batch and channel entries are copied from `input_shape`; the two
        spatial entries are replaced by `self.output_size`.
        """
        out_h, out_w = self.output_size
        if self.data_format == "channels_last":
            return (input_shape[0], out_h, out_w, input_shape[3])
        # channels_first
        return (input_shape[0], input_shape[1], out_h, out_w)

    def get_config(self):
        """Return the layer config, including `output_size` and `data_format`."""
        config_dict = {
            "output_size": self.output_size,
            "data_format": self.data_format,
        }
        base_config = super().get_config()
        return {**base_config, **config_dict}
Loading
Loading