Skip to content
Open
2 changes: 2 additions & 0 deletions keras/src/backend/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from keras.src.backend.jax.core import shape
from keras.src.backend.jax.core import stop_gradient
from keras.src.backend.jax.core import vectorized_map
from keras.src.backend.jax.nn import adaptive_avg_pool
from keras.src.backend.jax.nn import adaptive_max_pool
from keras.src.backend.jax.rnn import cudnn_ok
from keras.src.backend.jax.rnn import gru
from keras.src.backend.jax.rnn import lstm
Expand Down
365 changes: 365 additions & 0 deletions keras/src/backend/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1464,3 +1464,368 @@ def _pair(x):
# ---- reshape -> (N, C*kH*kW, L) ----
_, CKK, oH, oW = patches.shape
return patches.reshape(N, CKK, oH * oW)


def get_static_window_sizes(input_dim, output_dim):
    """Return the two candidate window lengths for one pooled axis.

    Args:
        input_dim: Length of the input axis.
        output_dim: Requested length of the pooled axis.

    Returns:
        A `(small_window, big_window)` tuple where `small_window` is
        `ceil(input_dim / output_dim)` and `big_window` is one larger.
    """
    base = math.ceil(input_dim / output_dim)
    return base, base + 1


def compute_static_gather_indices(input_dim, output_size, big_window):
    """Build indices that pick one pooled value per output position.

    The indices address the concatenation `[small_pool, big_pool]` along
    the pooled axis, where `small_pool` has
    `input_dim - small_window + 1` entries and `small_window` is
    `big_window - 1`.

    Args:
        input_dim: Length of the input axis.
        output_size: Number of output positions along that axis.
        big_window: The larger of the two candidate window lengths.

    Returns:
        An int32 array of length `output_size`.
    """
    # Fractional window boundaries 0, d/o, 2d/o, ..., d; consecutive
    # pairs delimit each output window.
    edges = (jnp.arange(output_size + 1) * input_dim) / output_size
    starts = jnp.floor(edges[:-1]).astype(jnp.int32)
    stops = jnp.ceil(edges[1:]).astype(jnp.int32)

    # Entries of the big pool sit after all entries of the small pool.
    small_pool_len = input_dim - (big_window - 1) + 1
    uses_big = (stops - starts) == big_window

    picked = jnp.where(uses_big, starts + small_pool_len, starts)
    return picked.astype(jnp.int32)


# ---------- 1D Adaptive Pooling ----------
def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"):
    """Adaptive average pooling over the length axis of a rank-3 input.

    Stride-1 "valid" window means are computed for two window sizes,
    concatenated along the length axis, and one entry per output
    position is gathered from the result.

    Args:
        inputs: Rank-3 array, `(N, C, L)` when `data_format` is
            `"channels_first"`, otherwise treated as `(N, L, C)`.
        output_size: Target length, as an int or a 1-element sequence.
        data_format: `"channels_first"` transposes to and from `(N, L, C)`;
            any other value leaves the layout untouched.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    if isinstance(output_size, int):
        output_size = (output_size,)

    channels_first = data_format == "channels_first"
    # All pooling below runs on the (N, L, C) layout.
    x = jnp.transpose(inputs, (0, 2, 1)) if channels_first else inputs

    _, length, _ = x.shape
    target_len = output_size[0]

    small_win, big_win = get_static_window_sizes(length, target_len)
    gather_idx = compute_static_gather_indices(length, target_len, big_win)

    # Sliding mean for each candidate window size.
    candidates = []
    for window in (small_win, big_win):
        total = lax.reduce_window(
            x, 0.0, lax.add, (1, window, 1), (1, 1, 1), "valid"
        )
        candidates.append(total / window)

    stacked = jnp.concatenate(candidates, axis=1)
    result = jnp.take(stacked, gather_idx, axis=1)

    return jnp.transpose(result, (0, 2, 1)) if channels_first else result


def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"):
    """Adaptive max pooling over the length axis of a rank-3 input.

    Stride-1 "valid" window maxima are computed for two window sizes,
    concatenated along the length axis, and one entry per output
    position is gathered from the result.

    Args:
        inputs: Rank-3 array, `(N, C, L)` when `data_format` is
            `"channels_first"`, otherwise treated as `(N, L, C)`.
        output_size: Target length, as an int or a 1-element sequence.
        data_format: `"channels_first"` transposes to and from `(N, L, C)`;
            any other value leaves the layout untouched.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    if isinstance(output_size, int):
        output_size = (output_size,)

    channels_first = data_format == "channels_first"
    # All pooling below runs on the (N, L, C) layout.
    x = jnp.transpose(inputs, (0, 2, 1)) if channels_first else inputs

    _, length, _ = x.shape
    target_len = output_size[0]

    small_win, big_win = get_static_window_sizes(length, target_len)
    gather_idx = compute_static_gather_indices(length, target_len, big_win)

    # Sliding maximum for each candidate window size.
    candidates = [
        lax.reduce_window(
            x, -jnp.inf, lax.max, (1, window, 1), (1, 1, 1), "valid"
        )
        for window in (small_win, big_win)
    ]

    stacked = jnp.concatenate(candidates, axis=1)
    result = jnp.take(stacked, gather_idx, axis=1)

    return jnp.transpose(result, (0, 2, 1)) if channels_first else result


# ---------- 2D Adaptive Pooling ----------
def adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"):
    """Adaptive average pooling over the two spatial axes of a rank-4 input.

    Each spatial axis is pooled in turn (height, then width). For an axis,
    stride-1 "valid" window means are computed for two window sizes,
    concatenated along that axis, and one entry per output position is
    gathered from the result.

    Args:
        inputs: Rank-4 array, `(N, C, H, W)` when `data_format` is
            `"channels_first"`, otherwise treated as `(N, H, W, C)`.
        output_size: Target `(height, width)`, or a single int used for
            both.
        data_format: `"channels_first"` transposes to and from
            `(N, H, W, C)`; any other value leaves the layout untouched.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    channels_first = data_format == "channels_first"
    # All pooling below runs on the (N, H, W, C) layout.
    x = jnp.transpose(inputs, (0, 2, 3, 1)) if channels_first else inputs

    _, in_h, in_w, _ = x.shape
    target_h, target_w = output_size

    # Resolve window sizes and gather indices for every axis up front.
    axis_plans = []
    for axis, (size, target) in enumerate(
        zip((in_h, in_w), (target_h, target_w)), start=1
    ):
        small_win, big_win = get_static_window_sizes(size, target)
        gather_idx = compute_static_gather_indices(size, target, big_win)
        axis_plans.append((axis, small_win, big_win, gather_idx))

    unit_strides = (1, 1, 1, 1)
    for axis, small_win, big_win, gather_idx in axis_plans:
        candidates = []
        for window in (small_win, big_win):
            # Window spans `window` entries on `axis` and 1 elsewhere.
            dims = (1,) * axis + (window,) + (1,) * (3 - axis)
            total = lax.reduce_window(
                x, 0.0, lax.add, dims, unit_strides, "valid"
            )
            candidates.append(total / window)
        stacked = jnp.concatenate(candidates, axis=axis)
        x = jnp.take(stacked, gather_idx, axis=axis)

    return jnp.transpose(x, (0, 3, 1, 2)) if channels_first else x


def adaptive_max_pool2d(inputs, output_size, data_format="channels_first"):
    """Adaptive max pooling over the two spatial axes of a rank-4 input.

    Each spatial axis is pooled in turn (height, then width). For an axis,
    stride-1 "valid" window maxima are computed for two window sizes,
    concatenated along that axis, and one entry per output position is
    gathered from the result.

    Args:
        inputs: Rank-4 array, `(N, C, H, W)` when `data_format` is
            `"channels_first"`, otherwise treated as `(N, H, W, C)`.
        output_size: Target `(height, width)`, or a single int used for
            both.
        data_format: `"channels_first"` transposes to and from
            `(N, H, W, C)`; any other value leaves the layout untouched.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size)

    channels_first = data_format == "channels_first"
    # All pooling below runs on the (N, H, W, C) layout.
    x = jnp.transpose(inputs, (0, 2, 3, 1)) if channels_first else inputs

    _, in_h, in_w, _ = x.shape
    target_h, target_w = output_size

    # Resolve window sizes and gather indices for every axis up front.
    axis_plans = []
    for axis, (size, target) in enumerate(
        zip((in_h, in_w), (target_h, target_w)), start=1
    ):
        small_win, big_win = get_static_window_sizes(size, target)
        gather_idx = compute_static_gather_indices(size, target, big_win)
        axis_plans.append((axis, small_win, big_win, gather_idx))

    unit_strides = (1, 1, 1, 1)
    for axis, small_win, big_win, gather_idx in axis_plans:
        candidates = []
        for window in (small_win, big_win):
            # Window spans `window` entries on `axis` and 1 elsewhere.
            dims = (1,) * axis + (window,) + (1,) * (3 - axis)
            candidates.append(
                lax.reduce_window(
                    x, -jnp.inf, lax.max, dims, unit_strides, "valid"
                )
            )
        stacked = jnp.concatenate(candidates, axis=axis)
        x = jnp.take(stacked, gather_idx, axis=axis)

    return jnp.transpose(x, (0, 3, 1, 2)) if channels_first else x


# ---------- 3D Adaptive Pooling ----------
def adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"):
    """Adaptive average pooling over the three spatial axes of a rank-5 input.

    Each spatial axis is pooled in turn (depth, height, width). For an
    axis, stride-1 "valid" window means are computed for two window sizes,
    concatenated along that axis, and one entry per output position is
    gathered from the result.

    Args:
        inputs: Rank-5 array, `(N, C, D, H, W)` when `data_format` is
            `"channels_first"`, otherwise treated as `(N, D, H, W, C)`.
        output_size: Target `(depth, height, width)`, or a single int used
            for all three.
        data_format: `"channels_first"` transposes to and from
            `(N, D, H, W, C)`; any other value leaves the layout untouched.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size, output_size)

    channels_first = data_format == "channels_first"
    # All pooling below runs on the (N, D, H, W, C) layout.
    x = jnp.transpose(inputs, (0, 2, 3, 4, 1)) if channels_first else inputs

    _, in_d, in_h, in_w, _ = x.shape
    target_d, target_h, target_w = output_size

    # Resolve window sizes and gather indices for every axis up front.
    axis_plans = []
    for axis, (size, target) in enumerate(
        zip((in_d, in_h, in_w), (target_d, target_h, target_w)), start=1
    ):
        small_win, big_win = get_static_window_sizes(size, target)
        gather_idx = compute_static_gather_indices(size, target, big_win)
        axis_plans.append((axis, small_win, big_win, gather_idx))

    unit_strides = (1, 1, 1, 1, 1)
    for axis, small_win, big_win, gather_idx in axis_plans:
        candidates = []
        for window in (small_win, big_win):
            # Window spans `window` entries on `axis` and 1 elsewhere.
            dims = (1,) * axis + (window,) + (1,) * (4 - axis)
            total = lax.reduce_window(
                x, 0.0, lax.add, dims, unit_strides, "valid"
            )
            candidates.append(total / window)
        stacked = jnp.concatenate(candidates, axis=axis)
        x = jnp.take(stacked, gather_idx, axis=axis)

    return jnp.transpose(x, (0, 4, 1, 2, 3)) if channels_first else x


def adaptive_max_pool3d(inputs, output_size, data_format="channels_first"):
    """Adaptive max pooling over the three spatial axes of a rank-5 input.

    Each spatial axis is pooled in turn (depth, height, width). For an
    axis, stride-1 "valid" window maxima are computed for two window
    sizes, concatenated along that axis, and one entry per output position
    is gathered from the result.

    Args:
        inputs: Rank-5 array, `(N, C, D, H, W)` when `data_format` is
            `"channels_first"`, otherwise treated as `(N, D, H, W, C)`.
        output_size: Target `(depth, height, width)`, or a single int used
            for all three.
        data_format: `"channels_first"` transposes to and from
            `(N, D, H, W, C)`; any other value leaves the layout untouched.

    Returns:
        The pooled array in the same layout as `inputs`.
    """
    if isinstance(output_size, int):
        output_size = (output_size, output_size, output_size)

    channels_first = data_format == "channels_first"
    # All pooling below runs on the (N, D, H, W, C) layout.
    x = jnp.transpose(inputs, (0, 2, 3, 4, 1)) if channels_first else inputs

    _, in_d, in_h, in_w, _ = x.shape
    target_d, target_h, target_w = output_size

    # Resolve window sizes and gather indices for every axis up front.
    axis_plans = []
    for axis, (size, target) in enumerate(
        zip((in_d, in_h, in_w), (target_d, target_h, target_w)), start=1
    ):
        small_win, big_win = get_static_window_sizes(size, target)
        gather_idx = compute_static_gather_indices(size, target, big_win)
        axis_plans.append((axis, small_win, big_win, gather_idx))

    unit_strides = (1, 1, 1, 1, 1)
    for axis, small_win, big_win, gather_idx in axis_plans:
        candidates = []
        for window in (small_win, big_win):
            # Window spans `window` entries on `axis` and 1 elsewhere.
            dims = (1,) * axis + (window,) + (1,) * (4 - axis)
            candidates.append(
                lax.reduce_window(
                    x, -jnp.inf, lax.max, dims, unit_strides, "valid"
                )
            )
        stacked = jnp.concatenate(candidates, axis=axis)
        x = jnp.take(stacked, gather_idx, axis=axis)

    return jnp.transpose(x, (0, 4, 1, 2, 3)) if channels_first else x


# ---------- Dispatchers ----------
def adaptive_avg_pool(inputs, output_size, data_format="channels_first"):
    """Route adaptive average pooling to the 1D, 2D or 3D implementation.

    The number of spatial axes is taken as `inputs.ndim - 2`.

    Args:
        inputs: Array with 1, 2 or 3 spatial axes plus two further axes.
        output_size: Forwarded unchanged to the selected implementation.
        data_format: Forwarded unchanged to the selected implementation.

    Returns:
        Whatever the selected implementation returns.

    Raises:
        ValueError: If `inputs.ndim - 2` is not 1, 2 or 3.
    """
    spatial_rank = inputs.ndim - 2
    if spatial_rank not in (1, 2, 3):
        raise ValueError(
            "adaptive_avg_pool supports 1D, 2D, or 3D inputs only."
        )

    # Looked up only after validation so the error path stays cheap.
    implementations = {
        1: adaptive_avg_pool1d,
        2: adaptive_avg_pool2d,
        3: adaptive_avg_pool3d,
    }
    return implementations[spatial_rank](inputs, output_size, data_format)


def adaptive_max_pool(inputs, output_size, data_format="channels_first"):
    """Route adaptive max pooling to the 1D, 2D or 3D implementation.

    The number of spatial axes is taken as `inputs.ndim - 2`.

    Args:
        inputs: Array with 1, 2 or 3 spatial axes plus two further axes.
        output_size: Forwarded unchanged to the selected implementation.
        data_format: Forwarded unchanged to the selected implementation.

    Returns:
        Whatever the selected implementation returns.

    Raises:
        ValueError: If `inputs.ndim - 2` is not 1, 2 or 3.
    """
    spatial_rank = inputs.ndim - 2
    if spatial_rank not in (1, 2, 3):
        raise ValueError(
            "adaptive_max_pool supports 1D, 2D, or 3D inputs only."
        )

    # Looked up only after validation so the error path stays cheap.
    implementations = {
        1: adaptive_max_pool1d,
        2: adaptive_max_pool2d,
        3: adaptive_max_pool3d,
    }
    return implementations[spatial_rank](inputs, output_size, data_format)
16 changes: 16 additions & 0 deletions keras/src/backend/numpy/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,3 +1237,19 @@ def _pair(x):

# ---- reshape -> (N, C*kH*kW, L) ----
return patches.reshape(N, C * k[0] * k[1], -1)


def adaptive_max_pool(inputs, output_size, data_format=None):
    """Placeholder for adaptive max pooling on the NumPy backend.

    All arguments are accepted for signature compatibility and ignored.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Adaptive pooling not implemented for Numpy. "
        "Use JAX, Torch or Tensorflow backend."
    )
    raise NotImplementedError(message)


def adaptive_avg_pool(inputs, output_size, data_format=None):
    """Placeholder for adaptive average pooling on the NumPy backend.

    All arguments are accepted for signature compatibility and ignored.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "Adaptive pooling not implemented for Numpy. "
        "Use JAX, Torch or Tensorflow backend."
    )
    raise NotImplementedError(message)
Loading
Loading